support multiple models
This commit is contained in:
parent
0588bc3f53
commit
312cdc5694
|
@ -28,7 +28,9 @@ autojoin_rooms:
|
||||||
# Should the bot set its avatar on login?
|
# Should the bot set its avatar on login?
|
||||||
#set_avatar: true
|
#set_avatar: true
|
||||||
|
|
||||||
command_prefix: '!c'
|
command:
|
||||||
|
gpt3_prefix: '!c3'
|
||||||
|
gpt4_prefix: '!c4' # optional
|
||||||
|
|
||||||
reply_in_thread: true
|
reply_in_thread: true
|
||||||
|
|
||||||
|
|
12
main.py
12
main.py
|
@ -48,8 +48,8 @@ check_config_value_exists(config_data['bot_auth'], 'homeserver')
|
||||||
check_config_value_exists(config_data['bot_auth'], 'store_path')
|
check_config_value_exists(config_data['bot_auth'], 'store_path')
|
||||||
check_config_value_exists(config_data, 'allowed_to_chat')
|
check_config_value_exists(config_data, 'allowed_to_chat')
|
||||||
check_config_value_exists(config_data, 'allowed_to_invite', allow_empty=True)
|
check_config_value_exists(config_data, 'allowed_to_invite', allow_empty=True)
|
||||||
check_config_value_exists(config_data, 'command_prefix')
|
|
||||||
check_config_value_exists(config_data, 'data_storage')
|
check_config_value_exists(config_data, 'data_storage')
|
||||||
|
check_config_value_exists(config_data, 'command')
|
||||||
|
|
||||||
check_config_value_exists(config_data, 'logging')
|
check_config_value_exists(config_data, 'logging')
|
||||||
check_config_value_exists(config_data['logging'], 'log_level')
|
check_config_value_exists(config_data['logging'], 'log_level')
|
||||||
|
@ -58,6 +58,13 @@ check_config_value_exists(config_data, 'openai')
|
||||||
check_config_value_exists(config_data['openai'], 'api_key')
|
check_config_value_exists(config_data['openai'], 'api_key')
|
||||||
check_config_value_exists(config_data['openai'], 'model')
|
check_config_value_exists(config_data['openai'], 'model')
|
||||||
|
|
||||||
|
gpt4_enabled = True if config_data['command'].get('gpt4_prefix') else False
|
||||||
|
logger.info(f'GPT4 enabled? {gpt4_enabled}')
|
||||||
|
|
||||||
|
command_prefixes = {}
|
||||||
|
for k, v in config_data['command'].items():
|
||||||
|
command_prefixes[k] = v
|
||||||
|
|
||||||
|
|
||||||
# check_config_value_exists(config_data, 'autojoin_rooms')
|
# check_config_value_exists(config_data, 'autojoin_rooms')
|
||||||
|
|
||||||
|
@ -102,7 +109,7 @@ async def main():
|
||||||
# Set up event callbacks
|
# Set up event callbacks
|
||||||
callbacks = Callbacks(client, storage,
|
callbacks = Callbacks(client, storage,
|
||||||
openai_obj=openai,
|
openai_obj=openai,
|
||||||
command_prefix=config_data['command_prefix'],
|
command_prefixes=command_prefixes,
|
||||||
openai_model=config_data['openai']['model'],
|
openai_model=config_data['openai']['model'],
|
||||||
reply_in_thread=config_data.get('reply_in_thread', False),
|
reply_in_thread=config_data.get('reply_in_thread', False),
|
||||||
allowed_to_invite=config_data['allowed_to_invite'],
|
allowed_to_invite=config_data['allowed_to_invite'],
|
||||||
|
@ -111,6 +118,7 @@ async def main():
|
||||||
system_prompt=config_data['openai'].get('system_prompt'),
|
system_prompt=config_data['openai'].get('system_prompt'),
|
||||||
injected_system_prompt=config_data['openai'].get('injected_system_prompt', False),
|
injected_system_prompt=config_data['openai'].get('injected_system_prompt', False),
|
||||||
openai_temperature=config_data['openai'].get('temperature', 0),
|
openai_temperature=config_data['openai'].get('temperature', 0),
|
||||||
|
gpt4_enabled=gpt4_enabled,
|
||||||
log_level=log_level
|
log_level=log_level
|
||||||
)
|
)
|
||||||
client.add_event_callback(callbacks.message, RoomMessageText)
|
client.add_event_callback(callbacks.message, RoomMessageText)
|
||||||
|
|
|
@ -5,10 +5,10 @@ import logging
|
||||||
import time
|
import time
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
|
|
||||||
from nio import (AsyncClient, InviteMemberEvent, JoinError, MatrixRoom, MegolmEvent, RoomMessageText, UnknownEvent, )
|
from nio import (AsyncClient, InviteMemberEvent, JoinError, MatrixRoom, MegolmEvent, RoomMessageText, UnknownEvent)
|
||||||
|
|
||||||
from .bot_commands import Command
|
from .bot_commands import Command
|
||||||
from .chat_functions import check_authorized, get_thread_content, is_this_our_thread, is_thread, process_chat, react_to_event, send_text_to_room
|
from .chat_functions import check_authorized, get_thread_content, is_this_our_thread, is_thread, process_chat, react_to_event, send_text_to_room, check_command_prefix
|
||||||
# from .config import Config
|
# from .config import Config
|
||||||
from .storage import Storage
|
from .storage import Storage
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ class Callbacks:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
client: AsyncClient,
|
client: AsyncClient,
|
||||||
store: Storage,
|
store: Storage,
|
||||||
command_prefix: str,
|
command_prefixes: dict,
|
||||||
openai_obj: ModuleType,
|
openai_obj: ModuleType,
|
||||||
openai_model: str,
|
openai_model: str,
|
||||||
reply_in_thread: bool,
|
reply_in_thread: bool,
|
||||||
|
@ -29,20 +29,19 @@ class Callbacks:
|
||||||
log_full_response: bool = False,
|
log_full_response: bool = False,
|
||||||
injected_system_prompt: str = False,
|
injected_system_prompt: str = False,
|
||||||
openai_temperature: float = 0,
|
openai_temperature: float = 0,
|
||||||
|
gpt4_enabled: bool = False,
|
||||||
log_level=logging.INFO
|
log_level=logging.INFO
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
client: nio client used to interact with matrix.
|
client: nio client used to interact with matrix.
|
||||||
|
|
||||||
store: Bot storage.
|
store: Bot storage.
|
||||||
|
|
||||||
config: Bot configuration parameters.
|
config: Bot configuration parameters.
|
||||||
"""
|
"""
|
||||||
self.client = client
|
self.client = client
|
||||||
self.store = store
|
self.store = store
|
||||||
# self.config = config
|
# self.config = config
|
||||||
self.command_prefix = command_prefix
|
self.command_prefixes = command_prefixes
|
||||||
self.openai_model = openai_model
|
self.openai_model = openai_model
|
||||||
self.startup_ts = time.time_ns() // 1_000_000
|
self.startup_ts = time.time_ns() // 1_000_000
|
||||||
self.reply_in_thread = reply_in_thread
|
self.reply_in_thread = reply_in_thread
|
||||||
|
@ -53,6 +52,7 @@ class Callbacks:
|
||||||
self.injected_system_prompt = injected_system_prompt
|
self.injected_system_prompt = injected_system_prompt
|
||||||
self.openai_obj = openai_obj
|
self.openai_obj = openai_obj
|
||||||
self.openai_temperature = openai_temperature
|
self.openai_temperature = openai_temperature
|
||||||
|
self.gpt4_enabled = gpt4_enabled
|
||||||
self.log_level = log_level
|
self.log_level = log_level
|
||||||
|
|
||||||
async def message(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
async def message(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
||||||
|
@ -60,7 +60,6 @@ class Callbacks:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room: The room the event came from.
|
room: The room the event came from.
|
||||||
|
|
||||||
event: The event defining the message.
|
event: The event defining the message.
|
||||||
"""
|
"""
|
||||||
# Extract the message text
|
# Extract the message text
|
||||||
|
@ -89,12 +88,12 @@ class Callbacks:
|
||||||
# else:
|
# else:
|
||||||
# has_command_prefix = False
|
# has_command_prefix = False
|
||||||
|
|
||||||
# room.is_group is often a DM, but not always.
|
command_activated, selected_model, sent_command_prefix = check_command_prefix(msg, self.command_prefixes)
|
||||||
# room.is_group does not allow room aliases
|
|
||||||
# room.member_count > 2 ... we assume a public room
|
|
||||||
# room.member_count <= 2 ... we assume a DM
|
|
||||||
# General message listener
|
# General message listener
|
||||||
if not msg.startswith(f'{self.command_prefix} ') and is_thread(event) and (await is_this_our_thread(self.client, room, event, self.command_prefix) or room.member_count == 2):
|
if not command_activated and is_thread(event):
|
||||||
|
is_our_thread, selected_model, sent_command_prefix = await is_this_our_thread(self.client, room, event, self.command_prefixes)
|
||||||
|
if is_our_thread or room.member_count == 2:
|
||||||
await self.client.room_typing(room.room_id, typing_state=True, timeout=3000)
|
await self.client.room_typing(room.room_id, typing_state=True, timeout=3000)
|
||||||
thread_content = await get_thread_content(self.client, room, event)
|
thread_content = await get_thread_content(self.client, room, event)
|
||||||
api_data = []
|
api_data = []
|
||||||
|
@ -110,8 +109,9 @@ class Callbacks:
|
||||||
api_data.append(
|
api_data.append(
|
||||||
{
|
{
|
||||||
'role': 'assistant' if event.sender == self.client.user_id else 'user',
|
'role': 'assistant' if event.sender == self.client.user_id else 'user',
|
||||||
'content': thread_msg if not thread_msg.startswith(self.command_prefix) else thread_msg[len(self.command_prefix):].strip()
|
'content': thread_msg if not check_command_prefix(thread_msg, self.command_prefixes) else thread_msg[len(self.command_prefixes):].strip()
|
||||||
}) # if len(thread_content) >= 2 and thread_content[0].body.startswith(self.command_prefix): # if thread_content[len(thread_content) - 2].sender == self.client.user
|
}
|
||||||
|
) # if len(thread_content) >= 2 and thread_content[0].body.startswith(self.command_prefix): # if thread_content[len(thread_content) - 2].sender == self.client.user
|
||||||
|
|
||||||
# TODO: process_chat() will set typing as false after generating.
|
# TODO: process_chat() will set typing as false after generating.
|
||||||
# TODO: If there is still another query in-progress that typing state will be overwritten by the one that just finished.
|
# TODO: If there is still another query in-progress that typing state will be overwritten by the one that just finished.
|
||||||
|
@ -123,7 +123,7 @@ class Callbacks:
|
||||||
api_data,
|
api_data,
|
||||||
self.store,
|
self.store,
|
||||||
openai_obj=self.openai_obj,
|
openai_obj=self.openai_obj,
|
||||||
openai_model=self.openai_model,
|
openai_model=selected_model,
|
||||||
openai_temperature=self.openai_temperature,
|
openai_temperature=self.openai_temperature,
|
||||||
thread_root_id=thread_content[0].event_id,
|
thread_root_id=thread_content[0].event_id,
|
||||||
system_prompt=self.system_prompt,
|
system_prompt=self.system_prompt,
|
||||||
|
@ -133,9 +133,9 @@ class Callbacks:
|
||||||
|
|
||||||
asyncio.get_event_loop().create_task(inner())
|
asyncio.get_event_loop().create_task(inner())
|
||||||
return
|
return
|
||||||
elif (msg.startswith(f'{self.command_prefix} ') or room.member_count == 2) and not is_thread(event):
|
elif (command_activated or room.member_count == 2) and not is_thread(event):
|
||||||
# Otherwise if this is in a 1-1 with the bot or features a command prefix, treat it as a command.
|
# Otherwise if this is in a 1-1 with the bot or features a command prefix, treat it as a command.
|
||||||
msg = msg if not msg.startswith(self.command_prefix) else msg[len(self.command_prefix):].strip() # Remove the command prefix
|
msg = msg if not command_activated else msg[len(self.command_prefixes):].strip() # Remove the command prefix
|
||||||
command = Command(
|
command = Command(
|
||||||
self.client,
|
self.client,
|
||||||
self.store,
|
self.store,
|
||||||
|
@ -143,7 +143,7 @@ class Callbacks:
|
||||||
room,
|
room,
|
||||||
event,
|
event,
|
||||||
openai_obj=self.openai_obj,
|
openai_obj=self.openai_obj,
|
||||||
openai_model=self.openai_model,
|
openai_model=selected_model,
|
||||||
openai_temperature=self.openai_temperature,
|
openai_temperature=self.openai_temperature,
|
||||||
reply_in_thread=self.reply_in_thread,
|
reply_in_thread=self.reply_in_thread,
|
||||||
system_prompt=self.system_prompt,
|
system_prompt=self.system_prompt,
|
||||||
|
@ -155,10 +155,11 @@ class Callbacks:
|
||||||
if self.log_level == logging.DEBUG:
|
if self.log_level == logging.DEBUG:
|
||||||
# This may be a little slow
|
# This may be a little slow
|
||||||
debug = {
|
debug = {
|
||||||
'command_prefix': msg.startswith(f'{self.command_prefix} '),
|
'command_prefix': sent_command_prefix,
|
||||||
|
'are_we_activated': command_activated,
|
||||||
'is_dm': room.member_count == 2,
|
'is_dm': room.member_count == 2,
|
||||||
'is_thread': is_thread(event),
|
'is_thread': is_thread(event),
|
||||||
'is_our_thread': await is_this_our_thread(self.client, room, event, self.command_prefix)
|
'is_our_thread': await is_this_our_thread(self.client, room, event, self.command_prefixes)[0]
|
||||||
|
|
||||||
}
|
}
|
||||||
logger.debug(f"Bot not reacting to event {event.event_id}: {json.dumps(debug)}")
|
logger.debug(f"Bot not reacting to event {event.event_id}: {json.dumps(debug)}")
|
||||||
|
@ -197,48 +198,6 @@ class Callbacks:
|
||||||
if event.state_key == self.client.user_id:
|
if event.state_key == self.client.user_id:
|
||||||
await self.invite(room, event)
|
await self.invite(room, event)
|
||||||
|
|
||||||
# async def _reaction(
|
|
||||||
# self, room: MatrixRoom, event: UnknownEvent, reacted_to_id: str
|
|
||||||
# ) -> None:
|
|
||||||
# """A reaction was sent to one of our messages. Let's send a reply acknowledging it.
|
|
||||||
#
|
|
||||||
# Args:
|
|
||||||
# room: The room the reaction was sent in.
|
|
||||||
#
|
|
||||||
# event: The reaction event.
|
|
||||||
#
|
|
||||||
# reacted_to_id: The event ID that the reaction points to.
|
|
||||||
# """
|
|
||||||
# logger.debug(f"Got reaction to {room.room_id} from {event.sender}.")
|
|
||||||
#
|
|
||||||
# # Get the original event that was reacted to
|
|
||||||
# event_response = await self.client.room_get_event(room.room_id, reacted_to_id)
|
|
||||||
# if isinstance(event_response, RoomGetEventError):
|
|
||||||
# logger.warning(
|
|
||||||
# "Error getting event that was reacted to (%s)", reacted_to_id
|
|
||||||
# )
|
|
||||||
# return
|
|
||||||
# reacted_to_event = event_response.event
|
|
||||||
#
|
|
||||||
# # Only acknowledge reactions to events that we sent
|
|
||||||
# if reacted_to_event.sender != self.config.user_id:
|
|
||||||
# return
|
|
||||||
#
|
|
||||||
# # Send a message acknowledging the reaction
|
|
||||||
# reaction_sender_pill = make_pill(event.sender)
|
|
||||||
# reaction_content = (
|
|
||||||
# event.source.get("content", {}).get("m.relates_to", {}).get("key")
|
|
||||||
# )
|
|
||||||
# message = (
|
|
||||||
# f"{reaction_sender_pill} reacted to this event with `{reaction_content}`!"
|
|
||||||
# )
|
|
||||||
# await send_text_to_room(
|
|
||||||
# self.client,
|
|
||||||
# room.room_id,
|
|
||||||
# message,
|
|
||||||
# reply_to_event_id=reacted_to_id,
|
|
||||||
# )
|
|
||||||
|
|
||||||
async def decryption_failure(self, room: MatrixRoom, event: MegolmEvent) -> None:
|
async def decryption_failure(self, room: MatrixRoom, event: MegolmEvent) -> None:
|
||||||
"""Callback for when an event fails to decrypt. Inform the user.
|
"""Callback for when an event fails to decrypt. Inform the user.
|
||||||
|
|
||||||
|
|
|
@ -112,17 +112,24 @@ def is_thread(event: RoomMessageText):
|
||||||
return event.source['content'].get('m.relates_to', {}).get('rel_type') == 'm.thread'
|
return event.source['content'].get('m.relates_to', {}).get('rel_type') == 'm.thread'
|
||||||
|
|
||||||
|
|
||||||
async def is_this_our_thread(client: AsyncClient, room: MatrixRoom, event: RoomMessageText, command_flag: str):
|
def check_command_prefix(string: str, prefixes: dict):
|
||||||
|
for k, v in prefixes.items():
|
||||||
|
if string.startswith(f'{v} '):
|
||||||
|
return True, k, v
|
||||||
|
return False, None, None
|
||||||
|
|
||||||
|
|
||||||
|
async def is_this_our_thread(client: AsyncClient, room: MatrixRoom, event: RoomMessageText, command_prefixes: dict) -> tuple[bool, any, any]:
|
||||||
base_event_id = event.source['content'].get('m.relates_to', {}).get('event_id')
|
base_event_id = event.source['content'].get('m.relates_to', {}).get('event_id')
|
||||||
if base_event_id:
|
if base_event_id:
|
||||||
e = await client.room_get_event(room.room_id, base_event_id)
|
e = await client.room_get_event(room.room_id, base_event_id)
|
||||||
if not isinstance(e, RoomGetEventResponse):
|
if not isinstance(e, RoomGetEventResponse):
|
||||||
logger.critical(f'Failed to get event in is_this_our_thread(): {vars(e)}')
|
logger.critical(f'Failed to get event in is_this_our_thread(): {vars(e)}')
|
||||||
return
|
return False, None, None
|
||||||
else:
|
else:
|
||||||
return e.event.body.startswith(f'{command_flag} ')
|
return check_command_prefix(e.event.body, command_prefixes)
|
||||||
else:
|
else:
|
||||||
return False
|
return False, None, None
|
||||||
|
|
||||||
|
|
||||||
async def get_thread_content(client: AsyncClient, room: MatrixRoom, base_event: RoomMessageText) -> List[Event]:
|
async def get_thread_content(client: AsyncClient, room: MatrixRoom, base_event: RoomMessageText) -> List[Event]:
|
||||||
|
|
Loading…
Reference in New Issue