From 312cdc5694324956cdb8308e470a80bd06e9a671 Mon Sep 17 00:00:00 2001 From: Cyberes Date: Fri, 31 Mar 2023 23:08:27 -0600 Subject: [PATCH] support multiple models --- config.sample.yaml | 4 +- main.py | 12 ++- matrix_gpt/bot/callbacks.py | 147 +++++++++++-------------------- matrix_gpt/bot/chat_functions.py | 15 +++- 4 files changed, 77 insertions(+), 101 deletions(-) diff --git a/config.sample.yaml b/config.sample.yaml index e432b8e..14baa1b 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -28,7 +28,9 @@ autojoin_rooms: # Should the bot set its avatar on login? #set_avatar: true -command_prefix: '!c' +command: + gpt3_prefix: '!c3' + gpt4_prefix: '!c4' # optional reply_in_thread: true diff --git a/main.py b/main.py index ec5877d..f0c4701 100755 --- a/main.py +++ b/main.py @@ -48,8 +48,8 @@ check_config_value_exists(config_data['bot_auth'], 'homeserver') check_config_value_exists(config_data['bot_auth'], 'store_path') check_config_value_exists(config_data, 'allowed_to_chat') check_config_value_exists(config_data, 'allowed_to_invite', allow_empty=True) -check_config_value_exists(config_data, 'command_prefix') check_config_value_exists(config_data, 'data_storage') +check_config_value_exists(config_data, 'command') check_config_value_exists(config_data, 'logging') check_config_value_exists(config_data['logging'], 'log_level') @@ -58,6 +58,13 @@ check_config_value_exists(config_data, 'openai') check_config_value_exists(config_data['openai'], 'api_key') check_config_value_exists(config_data['openai'], 'model') +gpt4_enabled = True if config_data['command'].get('gpt4_prefix') else False +logger.info(f'GPT4 enabled? {gpt4_enabled}') + +command_prefixes = {} +for k, v in config_data['command'].items(): + command_prefixes[k] = v + # check_config_value_exists(config_data, 'autojoin_rooms') @@ -102,7 +109,7 @@ async def main(): # Set up event callbacks callbacks = Callbacks(client, storage, openai_obj=openai, - command_prefix=config_data['command_prefix'], + command_prefixes=command_prefixes, openai_model=config_data['openai']['model'], reply_in_thread=config_data.get('reply_in_thread', False), allowed_to_invite=config_data['allowed_to_invite'], @@ -111,6 +118,7 @@ async def main(): system_prompt=config_data['openai'].get('system_prompt'), injected_system_prompt=config_data['openai'].get('injected_system_prompt', False), openai_temperature=config_data['openai'].get('temperature', 0), + gpt4_enabled=gpt4_enabled, log_level=log_level ) client.add_event_callback(callbacks.message, RoomMessageText) diff --git a/matrix_gpt/bot/callbacks.py b/matrix_gpt/bot/callbacks.py index 01323db..313289e 100644 --- a/matrix_gpt/bot/callbacks.py +++ b/matrix_gpt/bot/callbacks.py @@ -5,10 +5,10 @@ import logging import time from types import ModuleType -from nio import (AsyncClient, InviteMemberEvent, JoinError, MatrixRoom, MegolmEvent, RoomMessageText, UnknownEvent, ) +from nio import (AsyncClient, InviteMemberEvent, JoinError, MatrixRoom, MegolmEvent, RoomMessageText, UnknownEvent) from .bot_commands import Command -from .chat_functions import check_authorized, get_thread_content, is_this_our_thread, is_thread, process_chat, react_to_event, send_text_to_room +from .chat_functions import check_authorized, get_thread_content, is_this_our_thread, is_thread, process_chat, react_to_event, send_text_to_room, check_command_prefix # from .config import Config from .storage import Storage @@ -19,7 +19,7 @@ class Callbacks: def __init__(self, client: AsyncClient, store: Storage, - command_prefix: str, + command_prefixes: dict, openai_obj: ModuleType, openai_model: str, reply_in_thread: bool, @@ -29,20 +29,19 @@ class Callbacks: log_full_response: bool = False, injected_system_prompt: str = False, openai_temperature: float = 0, + gpt4_enabled: bool = False, log_level=logging.INFO ): """ Args: client: nio client used to interact with matrix. - store: Bot storage. - config: Bot configuration parameters. """ self.client = client self.store = store # self.config = config - self.command_prefix = command_prefix + self.command_prefixes = command_prefixes self.openai_model = openai_model self.startup_ts = time.time_ns() // 1_000_000 self.reply_in_thread = reply_in_thread @@ -53,6 +52,7 @@ class Callbacks: self.injected_system_prompt = injected_system_prompt self.openai_obj = openai_obj self.openai_temperature = openai_temperature + self.gpt4_enabled = gpt4_enabled self.log_level = log_level async def message(self, room: MatrixRoom, event: RoomMessageText) -> None: @@ -60,7 +60,6 @@ class Callbacks: Args: room: The room the event came from. - event: The event defining the message. """ # Extract the message text @@ -89,53 +88,54 @@ class Callbacks: # else: # has_command_prefix = False - # room.is_group is often a DM, but not always. - # room.is_group does not allow room aliases - # room.member_count > 2 ... we assume a public room - # room.member_count <= 2 ... we assume a DM + command_activated, selected_model, sent_command_prefix = check_command_prefix(msg, self.command_prefixes) + # General message listener - if not msg.startswith(f'{self.command_prefix} ') and is_thread(event) and (await is_this_our_thread(self.client, room, event, self.command_prefix) or room.member_count == 2): - await self.client.room_typing(room.room_id, typing_state=True, timeout=3000) - thread_content = await get_thread_content(self.client, room, event) - api_data = [] - for event in thread_content: - if isinstance(event, MegolmEvent): - resp = await send_text_to_room(self.client, room.room_id, '❌ 🔐 Decryption Failure', reply_to_event_id=event.event_id, thread=True, thread_root_id=thread_content[0].event_id) - logger.critical(f'Decryption failure for event {event.event_id} in room {room.room_id}') - await self.client.room_typing(room.room_id, typing_state=False, timeout=3000) - self.store.add_event_id(resp.event_id) - return - else: - thread_msg = event.body.strip().strip('\n') - api_data.append( - { - 'role': 'assistant' if event.sender == self.client.user_id else 'user', - 'content': thread_msg if not thread_msg.startswith(self.command_prefix) else thread_msg[len(self.command_prefix):].strip() - }) # if len(thread_content) >= 2 and thread_content[0].body.startswith(self.command_prefix): # if thread_content[len(thread_content) - 2].sender == self.client.user + if not command_activated and is_thread(event): + is_our_thread, selected_model, sent_command_prefix = await is_this_our_thread(self.client, room, event, self.command_prefixes) + if is_our_thread or room.member_count == 2: + await self.client.room_typing(room.room_id, typing_state=True, timeout=3000) + thread_content = await get_thread_content(self.client, room, event) + api_data = [] + for event in thread_content: + if isinstance(event, MegolmEvent): + resp = await send_text_to_room(self.client, room.room_id, '❌ 🔐 Decryption Failure', reply_to_event_id=event.event_id, thread=True, thread_root_id=thread_content[0].event_id) + logger.critical(f'Decryption failure for event {event.event_id} in room {room.room_id}') + await self.client.room_typing(room.room_id, typing_state=False, timeout=3000) + self.store.add_event_id(resp.event_id) + return + else: + thread_msg = event.body.strip().strip('\n') + api_data.append( + { + 'role': 'assistant' if event.sender == self.client.user_id else 'user', + 'content': thread_msg if not check_command_prefix(thread_msg, self.command_prefixes) else thread_msg[len(self.command_prefixes):].strip() + } + ) # if len(thread_content) >= 2 and thread_content[0].body.startswith(self.command_prefix): # if thread_content[len(thread_content) - 2].sender == self.client.user - # TODO: process_chat() will set typing as false after generating. - # TODO: If there is still another query in-progress that typing state will be overwritten by the one that just finished. - async def inner(): - await process_chat( - self.client, - room, - event, - api_data, - self.store, - openai_obj=self.openai_obj, - openai_model=self.openai_model, - openai_temperature=self.openai_temperature, - thread_root_id=thread_content[0].event_id, - system_prompt=self.system_prompt, - log_full_response=self.log_full_response, - injected_system_prompt=self.injected_system_prompt - ) + # TODO: process_chat() will set typing as false after generating. + # TODO: If there is still another query in-progress that typing state will be overwritten by the one that just finished. + async def inner(): + await process_chat( + self.client, + room, + event, + api_data, + self.store, + openai_obj=self.openai_obj, + openai_model=selected_model, + openai_temperature=self.openai_temperature, + thread_root_id=thread_content[0].event_id, + system_prompt=self.system_prompt, + log_full_response=self.log_full_response, + injected_system_prompt=self.injected_system_prompt + ) - asyncio.get_event_loop().create_task(inner()) + asyncio.get_event_loop().create_task(inner()) return - elif (msg.startswith(f'{self.command_prefix} ') or room.member_count == 2) and not is_thread(event): + elif (command_activated or room.member_count == 2) and not is_thread(event): # Otherwise if this is in a 1-1 with the bot or features a command prefix, treat it as a command. - msg = msg if not msg.startswith(self.command_prefix) else msg[len(self.command_prefix):].strip() # Remove the command prefix + msg = msg if not command_activated else msg[len(self.command_prefixes):].strip() # Remove the command prefix command = Command( self.client, self.store, @@ -143,7 +143,7 @@ class Callbacks: room, event, openai_obj=self.openai_obj, - openai_model=self.openai_model, + openai_model=selected_model, openai_temperature=self.openai_temperature, reply_in_thread=self.reply_in_thread, system_prompt=self.system_prompt, @@ -155,10 +155,11 @@ class Callbacks: if self.log_level == logging.DEBUG: # This may be a little slow debug = { - 'command_prefix': msg.startswith(f'{self.command_prefix} '), + 'command_prefix': sent_command_prefix, + 'are_we_activated': command_activated, 'is_dm': room.member_count == 2, 'is_thread': is_thread(event), - 'is_our_thread': await is_this_our_thread(self.client, room, event, self.command_prefix) + 'is_our_thread': await is_this_our_thread(self.client, room, event, self.command_prefixes)[0] } logger.debug(f"Bot not reacting to event {event.event_id}: {json.dumps(debug)}") @@ -197,48 +198,6 @@ class Callbacks: if event.state_key == self.client.user_id: await self.invite(room, event) - # async def _reaction( - # self, room: MatrixRoom, event: UnknownEvent, reacted_to_id: str - # ) -> None: - # """A reaction was sent to one of our messages. Let's send a reply acknowledging it. - # - # Args: - # room: The room the reaction was sent in. - # - # event: The reaction event. - # - # reacted_to_id: The event ID that the reaction points to. - # """ - # logger.debug(f"Got reaction to {room.room_id} from {event.sender}.") - # - # # Get the original event that was reacted to - # event_response = await self.client.room_get_event(room.room_id, reacted_to_id) - # if isinstance(event_response, RoomGetEventError): - # logger.warning( - # "Error getting event that was reacted to (%s)", reacted_to_id - # ) - # return - # reacted_to_event = event_response.event - # - # # Only acknowledge reactions to events that we sent - # if reacted_to_event.sender != self.config.user_id: - # return - # - # # Send a message acknowledging the reaction - # reaction_sender_pill = make_pill(event.sender) - # reaction_content = ( - # event.source.get("content", {}).get("m.relates_to", {}).get("key") - # ) - # message = ( - # f"{reaction_sender_pill} reacted to this event with `{reaction_content}`!" - # ) - # await send_text_to_room( - # self.client, - # room.room_id, - # message, - # reply_to_event_id=reacted_to_id, - # ) - async def decryption_failure(self, room: MatrixRoom, event: MegolmEvent) -> None: """Callback for when an event fails to decrypt. Inform the user. diff --git a/matrix_gpt/bot/chat_functions.py b/matrix_gpt/bot/chat_functions.py index eed073c..82d21ca 100644 --- a/matrix_gpt/bot/chat_functions.py +++ b/matrix_gpt/bot/chat_functions.py @@ -112,17 +112,24 @@ def is_thread(event: RoomMessageText): return event.source['content'].get('m.relates_to', {}).get('rel_type') == 'm.thread' -async def is_this_our_thread(client: AsyncClient, room: MatrixRoom, event: RoomMessageText, command_flag: str): +def check_command_prefix(string: str, prefixes: dict): + for k, v in prefixes.items(): + if string.startswith(f'{v} '): + return True, k, v + return False, None, None + + +async def is_this_our_thread(client: AsyncClient, room: MatrixRoom, event: RoomMessageText, command_prefixes: dict) -> tuple[bool, any, any]: base_event_id = event.source['content'].get('m.relates_to', {}).get('event_id') if base_event_id: e = await client.room_get_event(room.room_id, base_event_id) if not isinstance(e, RoomGetEventResponse): logger.critical(f'Failed to get event in is_this_our_thread(): {vars(e)}') - return + return False, None, None else: - return e.event.body.startswith(f'{command_flag} ') + return check_command_prefix(e.event.body, command_prefixes) else: - return False + return False, None, None async def get_thread_content(client: AsyncClient, room: MatrixRoom, base_event: RoomMessageText) -> List[Event]: