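"""Generate an OpenAI reply to a Matrix room event and post it back to the
room as a threaded message, managing the typing indicator and reacting to
the triggering event on errors or timeouts."""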
import asyncio
import logging
import traceback
from typing import Union

from nio import RoomSendResponse

from matrix_gpt import MatrixClientHelper
from matrix_gpt.config import global_config
from matrix_gpt.openai_client import openai_client

logger = logging.getLogger('ProcessChat')


# TODO: process_chat() sets typing to false after generating a response.
# TODO: If another query is still in progress, its typing state will be overwritten by the one that just finished.
async def generate_ai_response(
        client_helper: MatrixClientHelper,
        room,
        event,
        msg: Union[str, list],
        sent_command_prefix: str,
        openai_model: str,
        thread_root_id: str = None,
):
    client = client_helper.client
    try:
        # Show the typing indicator while the response is being generated.
        await client.room_typing(room.room_id, typing_state=True, timeout=global_config['response_timeout'] * 1000)

        # Set up the messages list.
        if isinstance(msg, list):
            messages = msg
        else:
            messages = [{'role': 'user', 'content': msg}]

        # Inject the system prompt.
        system_prompt = global_config['openai'].get('system_prompt', '')
        injected_system_prompt = global_config['openai'].get('injected_system_prompt', '')
        if isinstance(system_prompt, str) and len(system_prompt):
            messages.insert(0, {"role": "system", "content": system_prompt})
        if (isinstance(injected_system_prompt, str) and len(injected_system_prompt)) and len(messages) >= 3:
            # Only inject the system prompt if this isn't the first reply.
            if messages[-1]['role'] == 'system':
                # Delete the last system message since we want to replace it with our injected prompt.
                del messages[-1]
            # Insert the injected prompt just before the latest user message.
            messages.insert(-1, {"role": "system", "content": injected_system_prompt})
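        # For illustration, in a hypothetical mid-thread exchange `messages`
        # now looks something like:
        #   [{'role': 'system', 'content': system_prompt},
        #    {'role': 'user', 'content': 'first question'},
        #    {'role': 'assistant', 'content': 'first answer'},
        #    {'role': 'system', 'content': injected_system_prompt},
        #    {'role': 'user', 'content': 'follow-up question'}]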

        # max_tokens is configured per command prefix; 0 means "no limit" and
        # is translated to None for the API calls below.
        max_tokens = global_config.command_prefixes[sent_command_prefix]['max_tokens']

        async def generate():
            if openai_model in ['text-davinci-003', 'davinci-instruct-beta', 'text-davinci-001',
                                'text-davinci-002', 'text-curie-001', 'text-babbage-001']:
                # Legacy completion models take a single string prompt rather than a
                # chat-style message list, so the message contents are joined here
                # (assumes each message's content is a string).
                r = await openai_client.client().completions.create(
                    model=openai_model,
                    prompt='\n'.join(m['content'] for m in messages),
                    temperature=global_config['openai']['temperature'],
                    timeout=global_config['response_timeout'],
                    max_tokens=None if max_tokens == 0 else max_tokens
                )
                return r.choices[0].text
            else:
                r = await openai_client.client().chat.completions.create(
                    model=openai_model,
                    messages=messages,
                    temperature=global_config['openai']['temperature'],
                    timeout=global_config['response_timeout'],
                    max_tokens=None if max_tokens == 0 else max_tokens
                )
                return r.choices[0].message.content

        response = None
        try:
            # Bound the generation with the configured timeout; wait_for() also
            # cancels the in-flight request if it takes too long.
            response = await asyncio.wait_for(generate(), timeout=global_config['response_timeout'])
        except asyncio.TimeoutError:
            logger.warning(f'Response to event {event.event_id} timed out.')
            await client_helper.react_to_event(
                room.room_id,
                event.event_id,
                '🕒',
                extra_error='Request timed out.' if global_config['send_extra_messages'] else None
            )
            await client.room_typing(room.room_id, typing_state=False, timeout=1000)
            return
        except Exception:
            logger.error(f'Exception when generating for event {event.event_id}: {traceback.format_exc()}')
            await client_helper.react_to_event(
                room.room_id,
                event.event_id,
                '❌',
                extra_error='Exception' if global_config['send_extra_messages'] else None
            )
            await client.room_typing(room.room_id, typing_state=False, timeout=1000)
            return

        if not response:
            # Catches both None and an empty-string response.
            logger.warning(f'Response to event {event.event_id} in room {room.room_id} was null.')
            await client_helper.react_to_event(
                room.room_id,
                event.event_id,
                '❌',
                extra_error='Response was null.' if global_config['send_extra_messages'] else None
            )
            await client.room_typing(room.room_id, typing_state=False, timeout=1000)
            return

        # The AI's response.
        text_response = response.strip()

        # Logging
        if global_config['logging']['log_full_response']:
            logger.debug(
                {'event_id': event.event_id, 'room': room.room_id, 'messages': messages, 'response': response}
            )
        # Escape newlines so the reply logs as a single line.
        z = text_response.replace("\n", "\\n")
        logger.info(f'Reply to {event.event_id} --> {openai_model} responded with "{z}"')

        # Send the reply to the room as a threaded message.
        resp = await client_helper.send_text_to_room(
            room.room_id,
            text_response,
            reply_to_event_id=event.event_id,
            thread=True,
            thread_root_id=thread_root_id if thread_root_id else event.event_id
        )
        await client.room_typing(room.room_id, typing_state=False, timeout=1000)
        if not isinstance(resp, RoomSendResponse):
            logger.critical(f'Failed to respond to event {event.event_id} in room {room.room_id}:\n{vars(resp)}')
            await client_helper.react_to_event(room.room_id, event.event_id, '❌', extra_error='Exception' if global_config['send_extra_messages'] else None)
    except Exception:
        # Surface the failure to the user, then re-raise for the caller's error handling.
        await client_helper.react_to_event(room.room_id, event.event_id, '❌', extra_error='Exception' if global_config['send_extra_messages'] else None)
        raise
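
# A minimal usage sketch (the surrounding callback, its arguments, and the
# prefix/model values here are assumptions for illustration, not part of
# this module):
#
#     await generate_ai_response(
#         client_helper=client_helper,
#         room=room,
#         event=event,
#         msg=event.body,
#         sent_command_prefix='!c',
#         openai_model='gpt-4',
#         thread_root_id=None,
#     )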