add local model support

parent d53517e3fb
commit b47e1af873
@@ -34,6 +34,14 @@ command:
 reply_in_thread: true
 
+# The bot can add extra debug info to the sent messages in the format:
+#"m.matrixgpt": {
+# "error": "",
+# "msg": ""
+#}
+# This info is only visible using "View Source"
+send_extra_messages: true
+
 logging:
   log_level: info
 
@@ -42,7 +50,11 @@ logging:
 logout_other_devices: false
 
 
 openai:
+  # api_base: https://your-custom-backend/v1
   api_key: sk-J12J3O12U3J1LK2J310283JIJ1L2K3J
   model: gpt-3.5-turbo
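The new send_extra_messages option controls whether the bot embeds debug info in outgoing events under the custom m.matrixgpt key, which is only visible when opening the event with "View Source". A hedged sketch of what attaching that payload looks like (the real logic lives in send_text_to_room(), in the last file of this diff; attach_debug_info is a made-up name for illustration):

def attach_debug_info(content: dict, extra_error='', extra_msg='') -> dict:
    # Everything is stringified for now; the diff below leaves a TODO
    # about supporting arrays here.
    content["m.matrixgpt"] = {
        "error": str(extra_error),
        "msg": str(extra_msg),
    }
    return content

event_content = attach_debug_info({"msgtype": "m.text", "body": "hi"})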
main.py | 19 ++++++++++++-------
@@ -55,11 +55,6 @@ check_config_value_exists(config_data, 'logging')
 check_config_value_exists(config_data['logging'], 'log_level')
 
 check_config_value_exists(config_data, 'openai')
-check_config_value_exists(config_data['openai'], 'api_key')
-# check_config_value_exists(config_data['openai'], 'model')
-
-# gpt4_enabled = True if config_data['command'].get('gpt4_prefix') else False
-# logger.info(f'GPT4 enabled? {gpt4_enabled}')
 
 command_prefixes = {}
 for k, v in config_data['command'].items():
@@ -109,6 +104,14 @@ async def main():
     logger.info(f'Log level is {l}')
     del l
 
+    if len(config_data['command'].keys()) == 1 and config_data['command'][list(config_data['command'].keys())[0]]['mode'] == 'local':
+        # Need the logger to be initalized for this
+        logger.info('Running in local mode, OpenAI API key not required.')
+        openai.api_key = 'abc123'
+    else:
+        check_config_value_exists(config_data['openai'], 'api_key')
+        openai.api_key = config_data['openai']['api_key']
+
     logger.info(f'Command Prefixes: {[k for k, v in command_prefixes.items()]}')
 
     # Logging in with a new device each time seems to fix encryption errors
@@ -124,7 +127,9 @@ async def main():
     )
     client = matrix_helper.client
 
-    openai.api_key = config_data['openai']['api_key']
+    if config_data['openai'].get('api_base'):
+        logger.info(f'Set OpenAI API base URL to: {config_data["openai"].get("api_base")}')
+        openai.api_base = config_data['openai'].get('api_base')
 
     storage = Storage(Path(config_data['data_storage'], 'matrixgpt.db'))
 
@@ -140,7 +145,7 @@ async def main():
         system_prompt=config_data['openai'].get('system_prompt'),
         injected_system_prompt=config_data['openai'].get('injected_system_prompt', False),
         openai_temperature=config_data['openai'].get('temperature', 0),
-        # gpt4_enabled=gpt4_enabled,
+        send_extra_messages=config_data.get('send_extra_messages', False),
         log_level=log_level
     )
     client.add_event_callback(callbacks.message, RoomMessageText)
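The core of local model support: when the only configured command runs in 'local' mode, main.py skips API-key validation (a placeholder satisfies the client) and the new api_base option points the OpenAI client at a self-hosted, OpenAI-compatible server. A minimal sketch of that client setup, assuming the openai-python 0.x API this code uses and the placeholder URL from the sample config:

import openai  # openai-python 0.x, matching the ChatCompletion usage here

openai.api_key = 'abc123'  # local backends ignore the key, but one must be set
openai.api_base = 'https://your-custom-backend/v1'  # placeholder from the config

# 'local' is the sentinel model name this diff teaches generate() to accept.
r = openai.ChatCompletion.create(
    model='local',
    messages=[{'role': 'user', 'content': 'Hello!'}],
    temperature=0,
    max_tokens=320,  # the cap the diff applies to local models
)
print(r.choices[0].message.content)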
@@ -26,7 +26,8 @@ class Command:
             openai_temperature: float = 0,
             system_prompt: str = None,
             injected_system_prompt: str = None,
-            log_full_response: bool = False
+            log_full_response: bool = False,
+            send_extra_messages: bool = True
     ):
         """A command made by a user.
 
@@ -57,6 +58,7 @@ class Command:
         self.log_full_response = log_full_response
         self.openai_obj = openai_obj
         self.openai_temperature = openai_temperature
+        self.send_extra_messages = send_extra_messages
 
     async def process(self):
         """Process the command"""
@@ -88,7 +90,8 @@ class Command:
                 openai_temperature=self.openai_temperature,
                 system_prompt=self.system_prompt,
                 injected_system_prompt=self.injected_system_prompt,
-                log_full_response=self.log_full_response
+                log_full_response=self.log_full_response,
+                send_extra_messages=self.send_extra_messages
             )
 
         asyncio.get_event_loop().create_task(inner())
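The Command hunks are pure plumbing: the new flag enters the constructor, is stored on the instance, and rides along to the chat helper. A stripped-down sketch of the pattern (the real class takes many more parameters, and process_chat is the helper patched at the bottom of this diff):

class Command:
    def __init__(self, log_full_response: bool = False, send_extra_messages: bool = True):
        # stored at construction time alongside the other options
        self.log_full_response = log_full_response
        self.send_extra_messages = send_extra_messages

    async def process(self):
        # forwarded unchanged to the chat helper
        await process_chat(log_full_response=self.log_full_response,
                           send_extra_messages=self.send_extra_messages)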
@@ -30,8 +30,8 @@ class Callbacks:
             log_full_response: bool = False,
             injected_system_prompt: str = False,
             openai_temperature: float = 0,
-            log_level=logging.INFO
-            # gpt4_enabled: bool = False,
+            log_level=logging.INFO,
+            send_extra_messages: bool = False
     ):
         """
         Args:
@@ -55,6 +55,7 @@ class Callbacks:
         self.openai_temperature = openai_temperature
         # self.gpt4_enabled = gpt4_enabled
         self.log_level = log_level
+        self.send_extra_messages = send_extra_messages
 
     async def message(self, room: MatrixRoom, event: RoomMessageText) -> None:
         """Callback for when a message event is received
@@ -143,7 +144,8 @@ class Callbacks:
                 thread_root_id=thread_content[0].event_id,
                 system_prompt=self.system_prompt,
                 log_full_response=self.log_full_response,
-                injected_system_prompt=self.injected_system_prompt
+                injected_system_prompt=self.injected_system_prompt,
+                send_extra_messages=self.send_extra_messages
             )
 
         asyncio.get_event_loop().create_task(inner())
@@ -152,7 +154,7 @@ class Callbacks:
                 raise
             return
         elif (command_activated or room.member_count == 2) and not is_thread(event):  # Everything else
-            if not check_authorized(event.sender, command_info['allowed_to_chat']):
+            if command_info.get('allowed_to_chat') and not check_authorized(event.sender, command_info['allowed_to_chat']):
                 await react_to_event(self.client, room.room_id, event.event_id, "🚫")
                 return
             try:
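Besides the flag plumbing, the last Callbacks hunk hardens the authorization check: the old code indexed command_info['allowed_to_chat'] unconditionally, which breaks for commands with no allow-list configured; the new guard short-circuits instead. A self-contained sketch of the pattern, with check_authorized stubbed as a plain allow-list test (the real helper may do more):

def check_authorized(sender: str, allowed) -> bool:
    # stand-in for the bot's real helper: allow-list membership test
    return sender in allowed

def may_chat(sender: str, command_info: dict) -> bool:
    allowed = command_info.get('allowed_to_chat')  # missing -> unrestricted
    return not allowed or check_authorized(sender, allowed)

assert may_chat('@anyone:matrix.org', {})  # no allow-list configured
assert not may_chat('@rando:matrix.org', {'allowed_to_chat': ['@admin:matrix.org']})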
@@ -60,6 +60,8 @@ async def send_text_to_room(client: AsyncClient, room_id: str, message: str, not
             "event_id": reply_to_event_id
         }
     }
 
+    # TODO: don't force this to string. what if we want to send an array?
     content["m.matrixgpt"] = {
         "error": str(extra_error),
         "msg": str(extra_msg),
@@ -193,7 +195,8 @@ async def process_chat(
         thread_root_id: str = None,
         system_prompt: str = None,
         log_full_response: bool = False,
-        injected_system_prompt: str = False
+        injected_system_prompt: str = False,
+        send_extra_messages: bool = True
 ):
     try:
         if not store.check_seen_event(event.event_id):
@@ -226,16 +229,17 @@ async def process_chat(
         # I don't think the OpenAI py api has a built-in timeout
         @stopit.threading_timeoutable(default=(None, None))
         async def generate():
-            if openai_model.startswith('gpt-3') or openai_model.startswith('gpt-4'):
+            if openai_model.startswith('gpt-3') or openai_model.startswith('gpt-4') or openai_model == 'local':
                 r = await loop.run_in_executor(None, functools.partial(openai_obj.ChatCompletion.create,
                                                                        model=openai_model, messages=messages,
-                                                                       temperature=openai_temperature, timeout=20))
+                                                                       temperature=openai_temperature, timeout=900, max_tokens=None if openai_model != 'local' else 320))
                 return r.choices[0].message.content
             elif openai_model in ['text-davinci-003', 'davinci-instruct-beta', 'text-davinci-001',
                                   'text-davinci-002', 'text-curie-001', 'text-babbage-001']:
                 r = await loop.run_in_executor(None,
                                                functools.partial(openai_obj.Completion.create, model=openai_model,
-                                                                 temperature=openai_temperature, timeout=20,
+                                                                 temperature=openai_temperature,
+                                                                 request_timeout=900,
                                                                  max_tokens=4096))
                 return r.choices[0].text
             else:
@@ -243,10 +247,10 @@ async def process_chat(
 
         response = None
         openai_gen_error = None
-        for i in range(openai_retries):
+        for i in range(1, openai_retries):
             sleep_time = i * 5
             try:
-                task = asyncio.create_task(generate(timeout=20))
+                task = asyncio.create_task(generate(timeout=900))
                 asyncio.as_completed(task)
                 response = await task
                 if response is not None:
@@ -267,7 +271,8 @@ async def process_chat(
         if response is None:
             logger.critical(f'Response to event {event.event_id} in room {room.room_id} was null.')
             await client.room_typing(room.room_id, typing_state=False, timeout=15000)
-            await react_to_event(client, room.room_id, event.event_id, '❌', extra_error=openai_gen_error)
+            await react_to_event(client, room.room_id, event.event_id, '❌',
+                                 extra_error=(openai_gen_error if send_extra_messages else False))
             return
         text_response = response.strip().strip('\n')
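Local models generate far more slowly than the hosted API, which is why the timeouts jump from 20s to 900s throughout. The retry loop now starts at 1 so the linear backoff (i * 5) never sleeps for zero seconds between attempts; note this also trims one attempt off the total, since range(1, n) yields n-1 values. A runnable sketch of that retry shape, with the model call stubbed out (the real generate() is wrapped in stopit's threading timeout decorator):

import asyncio

async def generate(timeout: int = 900) -> str:
    # stub for the OpenAI/local call; the real one can block for minutes
    await asyncio.sleep(0.01)
    return 'hello'

async def retry_generate(openai_retries: int = 5):
    response, openai_gen_error = None, None
    for i in range(1, openai_retries):
        sleep_time = i * 5  # linear backoff: 5s, 10s, 15s, ...
        try:
            response = await generate(timeout=900)
            if response is not None:
                return response, None
        except Exception as e:
            openai_gen_error = e  # surfaced via the extra_error debug payload
        await asyncio.sleep(sleep_time)
    return None, openai_gen_error

print(asyncio.run(retry_generate())[0])  # -> 'hello'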