#!/usr/bin/env python3
import os
import re
import readline
import signal
import socket
import sys
from pathlib import Path

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_to_openai_function_messages
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableConfig
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_openai import ChatOpenAI
from termcolor import colored

import pers
from pers.langchain.agent import AgentInput
from pers.langchain.callbacks import CallbackHandler
from pers.langchain.history import HistoryManager
from pers.langchain.tools.agent_end_response import end_my_response
from pers.langchain.tools.bash import run_bash
from pers.langchain.tools.browser import render_webpage_tool
from pers.langchain.tools.document_manager import DocumentManager, retrieve_from_chroma
from pers.langchain.tools.google import search_google, search_google_news, search_google_maps
from pers.langchain.tools.python import run_python
from pers.langchain.tools.terminate import terminate_chat
from pers.langchain.tools.web_reader import read_webpage
from pers.load import load_config
from pers.personality import load_personality

def signal_handler(sig, frame):
    """Ctrl-C (SIGINT) handler: emit a newline so the shell prompt isn't mangled, then exit 0."""
    print()
    raise SystemExit(0)
# Keep pycharm from removing this import.
# (readline is imported only for its side effect of enabling line editing /
# history on input(); calling a harmless accessor marks the import as "used".)
readline.get_completion_type()

# Install the Ctrl-C handler defined above so SIGINT exits cleanly instead of
# dumping a KeyboardInterrupt traceback mid-prompt.
signal.signal(signal.SIGINT, signal_handler)
# Matches the literal "end_my_response" marker (optionally prefixed with
# "functions." and/or suffixed with "()") that the model sometimes emits as
# plain text instead of calling the end_my_response tool; the caller strips
# the marker and everything after it on the same line.
# Cleanups vs. the original: `(\n|\s)*` collapsed to `\s*` (`\n` is already in
# `\s`), groups made non-capturing (no group is ever read), and the useless
# re.MULTILINE flag dropped (it only affects `^`/`$`, which this pattern lacks).
MANUAL_STOP_RE = re.compile(r'\s*(?:functions\.)*end_my_response(?:\(\))*.*')
def init():
    """Read config.yml / character.yml next to this script, wire up the
    pers.GLOBALS singletons, and render the character card.

    Returns:
        A (program_config, character_config, character_card) tuple consumed
        positionally by main().
    """
    script_dir = Path(os.path.realpath(__file__)).parent
    program_config = load_config(
        script_dir / 'config.yml',
        ['openai_key', 'player_name', 'flush_redis_on_launch'],
    )
    character_config = load_config(
        script_dir / 'character.yml',
        ['name', 'personality', 'system_desc', 'gender', 'temperature', 'model'],
    )

    # Populate the global singletons used throughout the pers package.
    pers.GLOBALS.OPENAI_KEY = program_config['openai_key']
    pers.GLOBALS.ChatHistory = HistoryManager(
        flush=program_config['flush_redis_on_launch'],
        timestamp_messages=program_config.get('timestamp_messages', False),
    )
    pers.GLOBALS.DocumentManager = DocumentManager()

    # SerpAPI is optional; only expose the key when it is configured.
    serpapi_key = program_config.get('serpapi_api_key')
    if serpapi_key:
        pers.GLOBALS.SERPAPI_API_KEY = serpapi_key

    pers.GLOBALS.OpenAI = ChatOpenAI(
        model_name=character_config['model'],
        openai_api_key=program_config['openai_key'],
        temperature=character_config['temperature'],
    )

    character_card = load_personality(
        player_name=program_config['player_name'],
        name=character_config['name'],
        personality=character_config['personality'],
        system=character_config['system_desc'],
        gender=character_config['gender'],
        special_instructions=character_config.get('special_instructions'),
        player_location=program_config.get('player_location'),
    )
    return program_config, character_config, character_card
def main():
    """Interactive chat REPL.

    Reads user input in a loop, runs the OpenAI-functions agent until it
    signals the end of its response (either by calling the end_my_response
    tool or by emitting the literal marker text), and prints each agent
    message in color. Exits on EOF (Ctrl-D) or SIGINT (via signal_handler).
    """
    program_config, character_config, character_card = init()

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", character_card),
            MessagesPlaceholder(variable_name="chat_history"),
            ("user", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )

    tools = [end_my_response, run_bash, run_python, search_google, search_google_maps, search_google_news, retrieve_from_chroma, read_webpage, render_webpage_tool, terminate_chat]
    llm_with_tools = pers.GLOBALS.OpenAI.bind(functions=[convert_to_openai_function(t) for t in tools])

    # Standard LCEL OpenAI-functions agent pipeline.
    agent = (
        {
            "input": lambda x: x["input"],
            "chat_history": lambda x: x["chat_history"],
            "agent_scratchpad": lambda x: format_to_openai_function_messages(
                x["intermediate_steps"]
            ),
        }
        | prompt
        | llm_with_tools
        | OpenAIFunctionsAgentOutputParser()
    )

    handler = CallbackHandler()
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=False, callbacks=[handler]).with_types(
        input_type=AgentInput
    )

    print(colored(f'System Management Intelligence Interface', 'green', attrs=['bold']) + ' ' + colored(character_config['name'], 'green', attrs=['bold', 'underline']) + colored(' on ', 'green', attrs=['bold']) + colored(socket.gethostname(), 'green', attrs=['bold', 'underline']) + '\n')

    while True:
        try:
            # input() already returns str; the original str() wrapper was redundant.
            next_input = input('> ')
        except EOFError:
            # Ctrl-D at the prompt: exit cleanly.
            print('Exit')
            sys.exit(0)
        print('')
        pers.GLOBALS.ChatHistory.add_human_msg(next_input)

        i = 0
        while True:
            # Copy the history so the SystemMessage injected below stays local
            # to this invocation. The original appended directly to
            # ChatHistory.history, which leaks (and accumulates) the nudge
            # message in the persistent history if .history exposes the live
            # list. NOTE(review): assumes .history yields the message list —
            # confirm against HistoryManager.
            temp_context = list(pers.GLOBALS.ChatHistory.history)
            if i > 0:
                # Insert a prompt if this is not the first message.
                temp_context.append(
                    SystemMessage(content="Evaluate your progress on the current task. Call `end_my_response` after you have responded and are ready for the human's next message.")
                )

            if pers.GLOBALS.ChatHistory.acknowledge_stop():
                break

            result = agent_executor.invoke({"input": next_input, "chat_history": temp_context}, config=RunnableConfig(callbacks=[handler]))

            # Langchain and the agent are really struggling with end_my_response.
            # If the agent gets confused and puts the string "end_my_response" at the end of the msg rather than calling the function, end it manually.
            # Single sub() pass replaces the original search()-then-sub() double scan;
            # any match removes at least the literal marker, so an unchanged string
            # means there was no match.
            output = result['output']
            stripped = MANUAL_STOP_RE.sub('', output)
            do_stop = stripped != output
            output = stripped

            pers.GLOBALS.ChatHistory.add_agent_msg(output)

            # We need to print each line individually since the colored text doesn't support the "\n" character
            for line in output.split('\n'):
                print(colored(line, 'blue'))
            print()

            if do_stop:
                break

            i += 1
# Script entry point.
if __name__ == "__main__":
    main()