add local model support
This commit is contained in:
parent
d53517e3fb
commit
b47e1af873
|
@ -34,6 +34,14 @@ command:
|
|||
|
||||
reply_in_thread: true
|
||||
|
||||
# The bot can add extra debug info to the sent messages in the format:
|
||||
#"m.matrixgpt": {
|
||||
# "error": "",
|
||||
# "msg": ""
|
||||
#}
|
||||
# This info is only visible using "View Source"
|
||||
send_extra_messages: true
|
||||
|
||||
logging:
|
||||
log_level: info
|
||||
|
||||
|
@ -42,7 +50,11 @@ logging:
|
|||
|
||||
logout_other_devices: false
|
||||
|
||||
|
||||
|
||||
openai:
|
||||
# api_base: https://your-custom-backend/v1
|
||||
|
||||
api_key: sk-J12J3O12U3J1LK2J310283JIJ1L2K3J
|
||||
|
||||
model: gpt-3.5-turbo
|
||||
|
|
19
main.py
19
main.py
|
@ -55,11 +55,6 @@ check_config_value_exists(config_data, 'logging')
|
|||
check_config_value_exists(config_data['logging'], 'log_level')
|
||||
|
||||
check_config_value_exists(config_data, 'openai')
|
||||
check_config_value_exists(config_data['openai'], 'api_key')
|
||||
# check_config_value_exists(config_data['openai'], 'model')
|
||||
|
||||
# gpt4_enabled = True if config_data['command'].get('gpt4_prefix') else False
|
||||
# logger.info(f'GPT4 enabled? {gpt4_enabled}')
|
||||
|
||||
command_prefixes = {}
|
||||
for k, v in config_data['command'].items():
|
||||
|
@ -109,6 +104,14 @@ async def main():
|
|||
logger.info(f'Log level is {l}')
|
||||
del l
|
||||
|
||||
if len(config_data['command'].keys()) == 1 and config_data['command'][list(config_data['command'].keys())[0]]['mode'] == 'local':
|
||||
# Need the logger to be initalized for this
|
||||
logger.info('Running in local mode, OpenAI API key not required.')
|
||||
openai.api_key = 'abc123'
|
||||
else:
|
||||
check_config_value_exists(config_data['openai'], 'api_key')
|
||||
openai.api_key = config_data['openai']['api_key']
|
||||
|
||||
logger.info(f'Command Prefixes: {[k for k, v in command_prefixes.items()]}')
|
||||
|
||||
# Logging in with a new device each time seems to fix encryption errors
|
||||
|
@ -124,7 +127,9 @@ async def main():
|
|||
)
|
||||
client = matrix_helper.client
|
||||
|
||||
openai.api_key = config_data['openai']['api_key']
|
||||
if config_data['openai'].get('api_base'):
|
||||
logger.info(f'Set OpenAI API base URL to: {config_data["openai"].get("api_base")}')
|
||||
openai.api_base = config_data['openai'].get('api_base')
|
||||
|
||||
storage = Storage(Path(config_data['data_storage'], 'matrixgpt.db'))
|
||||
|
||||
|
@ -140,7 +145,7 @@ async def main():
|
|||
system_prompt=config_data['openai'].get('system_prompt'),
|
||||
injected_system_prompt=config_data['openai'].get('injected_system_prompt', False),
|
||||
openai_temperature=config_data['openai'].get('temperature', 0),
|
||||
# gpt4_enabled=gpt4_enabled,
|
||||
send_extra_messages=config_data.get('send_extra_messages', False),
|
||||
log_level=log_level
|
||||
)
|
||||
client.add_event_callback(callbacks.message, RoomMessageText)
|
||||
|
|
|
@ -26,7 +26,8 @@ class Command:
|
|||
openai_temperature: float = 0,
|
||||
system_prompt: str = None,
|
||||
injected_system_prompt: str = None,
|
||||
log_full_response: bool = False
|
||||
log_full_response: bool = False,
|
||||
send_extra_messages: bool = True
|
||||
):
|
||||
"""A command made by a user.
|
||||
|
||||
|
@ -57,6 +58,7 @@ class Command:
|
|||
self.log_full_response = log_full_response
|
||||
self.openai_obj = openai_obj
|
||||
self.openai_temperature = openai_temperature
|
||||
self.send_extra_messages = send_extra_messages
|
||||
|
||||
async def process(self):
|
||||
"""Process the command"""
|
||||
|
@ -88,7 +90,8 @@ class Command:
|
|||
openai_temperature=self.openai_temperature,
|
||||
system_prompt=self.system_prompt,
|
||||
injected_system_prompt=self.injected_system_prompt,
|
||||
log_full_response=self.log_full_response
|
||||
log_full_response=self.log_full_response,
|
||||
send_extra_messages=self.send_extra_messages
|
||||
)
|
||||
|
||||
asyncio.get_event_loop().create_task(inner())
|
||||
|
|
|
@ -30,8 +30,8 @@ class Callbacks:
|
|||
log_full_response: bool = False,
|
||||
injected_system_prompt: str = False,
|
||||
openai_temperature: float = 0,
|
||||
log_level=logging.INFO
|
||||
# gpt4_enabled: bool = False,
|
||||
log_level=logging.INFO,
|
||||
send_extra_messages: bool = False
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -55,6 +55,7 @@ class Callbacks:
|
|||
self.openai_temperature = openai_temperature
|
||||
# self.gpt4_enabled = gpt4_enabled
|
||||
self.log_level = log_level
|
||||
self.send_extra_messages = send_extra_messages
|
||||
|
||||
async def message(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
||||
"""Callback for when a message event is received
|
||||
|
@ -143,7 +144,8 @@ class Callbacks:
|
|||
thread_root_id=thread_content[0].event_id,
|
||||
system_prompt=self.system_prompt,
|
||||
log_full_response=self.log_full_response,
|
||||
injected_system_prompt=self.injected_system_prompt
|
||||
injected_system_prompt=self.injected_system_prompt,
|
||||
send_extra_messages=self.send_extra_messages
|
||||
)
|
||||
|
||||
asyncio.get_event_loop().create_task(inner())
|
||||
|
@ -152,7 +154,7 @@ class Callbacks:
|
|||
raise
|
||||
return
|
||||
elif (command_activated or room.member_count == 2) and not is_thread(event): # Everything else
|
||||
if not check_authorized(event.sender, command_info['allowed_to_chat']):
|
||||
if command_info.get('allowed_to_chat') and not check_authorized(event.sender, command_info['allowed_to_chat']):
|
||||
await react_to_event(self.client, room.room_id, event.event_id, "🚫")
|
||||
return
|
||||
try:
|
||||
|
|
|
@ -60,6 +60,8 @@ async def send_text_to_room(client: AsyncClient, room_id: str, message: str, not
|
|||
"event_id": reply_to_event_id
|
||||
}
|
||||
}
|
||||
|
||||
# TODO: don't force this to string. what if we want to send an array?
|
||||
content["m.matrixgpt"] = {
|
||||
"error": str(extra_error),
|
||||
"msg": str(extra_msg),
|
||||
|
@ -193,7 +195,8 @@ async def process_chat(
|
|||
thread_root_id: str = None,
|
||||
system_prompt: str = None,
|
||||
log_full_response: bool = False,
|
||||
injected_system_prompt: str = False
|
||||
injected_system_prompt: str = False,
|
||||
send_extra_messages: bool = True
|
||||
):
|
||||
try:
|
||||
if not store.check_seen_event(event.event_id):
|
||||
|
@ -226,16 +229,17 @@ async def process_chat(
|
|||
# I don't think the OpenAI py api has a built-in timeout
|
||||
@stopit.threading_timeoutable(default=(None, None))
|
||||
async def generate():
|
||||
if openai_model.startswith('gpt-3') or openai_model.startswith('gpt-4'):
|
||||
if openai_model.startswith('gpt-3') or openai_model.startswith('gpt-4') or openai_model == 'local':
|
||||
r = await loop.run_in_executor(None, functools.partial(openai_obj.ChatCompletion.create,
|
||||
model=openai_model, messages=messages,
|
||||
temperature=openai_temperature, timeout=20))
|
||||
temperature=openai_temperature, timeout=900, max_tokens=None if openai_model != 'local' else 320))
|
||||
return r.choices[0].message.content
|
||||
elif openai_model in ['text-davinci-003', 'davinci-instruct-beta', 'text-davinci-001',
|
||||
'text-davinci-002', 'text-curie-001', 'text-babbage-001']:
|
||||
r = await loop.run_in_executor(None,
|
||||
functools.partial(openai_obj.Completion.create, model=openai_model,
|
||||
temperature=openai_temperature, timeout=20,
|
||||
temperature=openai_temperature,
|
||||
request_timeout=900,
|
||||
max_tokens=4096))
|
||||
return r.choices[0].text
|
||||
else:
|
||||
|
@ -243,10 +247,10 @@ async def process_chat(
|
|||
|
||||
response = None
|
||||
openai_gen_error = None
|
||||
for i in range(openai_retries):
|
||||
for i in range(1, openai_retries):
|
||||
sleep_time = i * 5
|
||||
try:
|
||||
task = asyncio.create_task(generate(timeout=20))
|
||||
task = asyncio.create_task(generate(timeout=900))
|
||||
asyncio.as_completed(task)
|
||||
response = await task
|
||||
if response is not None:
|
||||
|
@ -267,7 +271,8 @@ async def process_chat(
|
|||
if response is None:
|
||||
logger.critical(f'Response to event {event.event_id} in room {room.room_id} was null.')
|
||||
await client.room_typing(room.room_id, typing_state=False, timeout=15000)
|
||||
await react_to_event(client, room.room_id, event.event_id, '❌', extra_error=openai_gen_error)
|
||||
await react_to_event(client, room.room_id, event.event_id, '❌',
|
||||
extra_error=(openai_gen_error if send_extra_messages else False))
|
||||
return
|
||||
text_response = response.strip().strip('\n')
|
||||
|
||||
|
|
Loading…
Reference in New Issue