import asyncio
import json
import logging
import traceback
from typing import Union

from nio import RoomSendResponse, MatrixRoom, RoomMessageText

from matrix_gpt import MatrixClientHelper
from matrix_gpt.api_client_manager import api_client_helper
from matrix_gpt.config import global_config
from matrix_gpt.generate_clients.command_info import CommandInfo

logger = logging.getLogger('MatrixGPT').getChild('Generate')

# TODO: process_chat() will set typing as false after generating.
# TODO: If there is still another query in progress, that typing state will be overwritten by the one that just finished.


async def generate_ai_response(
        client_helper: MatrixClientHelper,
        room: MatrixRoom,
        event: RoomMessageText,
        context: Union[str, list],
        command_info: CommandInfo,
        thread_root_id: str = None,
        matrix_gpt_data: str = None
):
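    """
    Generate an AI response to `event` and send it back to `room` as a threaded reply.

    `context` is either a plain string (the first message in a thread) or the list of
    all messages in the thread so far. The typing indicator is turned on while the
    model generates and is cleared again on every exit path.
    """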
    assert isinstance(command_info, CommandInfo)
    client = client_helper.client
    try:
        await client.room_typing(room.room_id, typing_state=True, timeout=global_config['response_timeout'] * 1000)
        api_client = api_client_helper.get_client(command_info.api_type, client_helper, room, event)
        if not api_client:
            # If this was None then we were missing an API key for this client type. The error has already been logged.
            await client_helper.react_to_event(
                room.room_id,
                event.event_id,
                '❌',
                extra_error=f'No API key for model {command_info.model}' if global_config['send_extra_messages'] else None
            )
            await client.room_typing(room.room_id, typing_state=False, timeout=1000)
            return
        # The input context is either a string, if this is the first message in the thread, or a list of all messages in the thread.
        # Handling this here instead of in the caller simplifies things.
        if isinstance(context, str):
            context = [{'role': api_client.HUMAN_NAME, 'content': context}]
        # Build the context and do whatever our specific API type requires.
        api_client.prepare_context(context, system_prompt=command_info.system_prompt, injected_system_prompt=command_info.injected_system_prompt)

        if api_client.check_ignore_request():
            logger.debug(f'Reply to {event.event_id} was ignored by the model "{command_info.model}".')
            await client.room_typing(room.room_id, typing_state=False, timeout=1000)
            return
        response = None
        extra_data = None
        try:
            task = asyncio.create_task(api_client.generate(command_info, matrix_gpt_data))
            # asyncio.as_completed() enforces the configured response timeout while we wait on the generation task.
            for task in asyncio.as_completed([task], timeout=global_config['response_timeout']):
                # TODO: add a while loop and heartbeat the background thread
                try:
                    response, extra_data = await task
                    break
                except asyncio.TimeoutError:
                    logger.warning(f'Response to event {event.event_id} timed out.')
                    await client_helper.react_to_event(
                        room.room_id,
                        event.event_id,
                        '🕒',
                        extra_error='Request timed out.' if global_config['send_extra_messages'] else None
                    )
                    await client.room_typing(room.room_id, typing_state=False, timeout=1000)
                    return
        except Exception:
            logger.error(f'Exception when generating for event {event.event_id}: {traceback.format_exc()}')
            await client_helper.react_to_event(
                room.room_id,
                event.event_id,
                '❌',
                extra_error='Exception while generating AI response' if global_config['send_extra_messages'] else None
            )
            await client.room_typing(room.room_id, typing_state=False, timeout=1000)
            return
        if not response:
            logger.warning(f'Response to event {event.event_id} in room {room.room_id} was null.')
            await client_helper.react_to_event(
                room.room_id,
                event.event_id,
                '❌',
                extra_error='AI response was empty' if global_config['send_extra_messages'] else None
            )
            await client.room_typing(room.room_id, typing_state=False, timeout=1000)
            return
        # The AI's response.
        text_response = response.strip().strip('\n')
        if not extra_data:
            extra_data = {}
        # Logging
        if global_config['logging']['log_full_response']:
            assembled_context = api_client.context
            data = {'event_id': event.event_id, 'room': room.room_id, 'messages': assembled_context, 'response': response}
            # Remove images from the logged data.
            for i in range(len(data['messages'])):
                if isinstance(data['messages'][i]['content'], list):
                    # Images are always sent as lists.
                    if data['messages'][i]['content'][0].get('source', {}).get('media_type'):
                        # Anthropic
                        data['messages'][i]['content'][0]['source']['data'] = '...'
                    elif data['messages'][i]['content'][0].get('image_url'):
                        # OpenAI
                        data['messages'][i]['content'][0]['image_url']['url'] = '...'
            logger.debug(json.dumps(data))
        z = text_response.replace("\n", "\\n")
        logger.info(f'Reply to {event.event_id} --> {command_info.model} responded with "{z}"')
        # Send message to room
        resp = await client_helper.send_text_to_room(
            room.room_id,
            text_response,
            reply_to_event_id=event.event_id,
            thread=True,
            thread_root_id=thread_root_id if thread_root_id else event.event_id,
            markdown_convert=True,
            extra_data=extra_data
        )
        await client.room_typing(room.room_id, typing_state=False, timeout=1000)
        if not isinstance(resp, RoomSendResponse):
            logger.critical(f'Failed to respond to event {event.event_id} in room {room.room_id}:\n{vars(resp)}')
            await client_helper.react_to_event(room.room_id, event.event_id, '❌', extra_error='Exception while responding to event' if global_config['send_extra_messages'] else None)
    except Exception as e:
        await client_helper.react_to_event(room.room_id, event.event_id, '❌', extra_error=f'Exception during response process: {e}' if global_config['send_extra_messages'] else None)
        raise