64 lines
1.7 KiB
Python
64 lines
1.7 KiB
Python
import code
|
|
import multiprocessing
|
|
import sys
|
|
from io import StringIO
|
|
from typing import Optional
|
|
|
|
from langchain_core.tools import tool
|
|
|
|
from pers.langchain.tools.tools import PRINT_USAGE, _print_func_call
|
|
|
|
"""
|
|
This is very similar to PythonREPL(), except that we execute our code in `code.InteractiveConsole()`, which
|
|
is a better simulation of the REPL environment.
|
|
"""
|
|
|
|
|
|
def _run_statements_in_console(console: code.InteractiveConsole, py_code: str) -> None:
    """Run each top-level statement of ``py_code`` in ``console`` as one REPL input.

    ``InteractiveConsole.push`` expects a single *line*. Pushing a whole snippet
    compiles it in ``'single'`` mode, which rejects more than one statement
    ("multiple statements found ...") and treats a compound statement without a
    trailing blank line as incomplete, so it is never executed. Parsing the
    whole snippet first and compiling one statement at a time avoids both
    problems, and also tolerates blank lines inside function bodies.

    ``from __future__`` imports are not carried over to later statements.
    A ``SystemExit`` raised by the code stops execution of the remaining statements.
    """
    import ast

    try:
        tree = ast.parse(py_code, console.filename, 'exec')
    except (OverflowError, SyntaxError, ValueError):
        # Same exception set InteractiveInterpreter.runsource reports as a syntax error.
        console.showsyntaxerror(console.filename)
        return

    for statement in tree.body:
        try:
            # 'single' mode gives REPL semantics: bare expression values are echoed.
            compiled = compile(ast.Interactive(body=[statement]), console.filename, 'single')
        except (OverflowError, SyntaxError, ValueError):
            # Some errors only surface at compile time (e.g. `return` outside a function).
            console.showsyntaxerror(console.filename)
            continue
        try:
            # runcode prints a traceback for ordinary exceptions and re-raises SystemExit.
            console.runcode(compiled)
        except SystemExit:
            # exit()/sys.exit() ends an interactive session: skip the remaining statements.
            return


def _run_py_worker(py_code: str, queue: multiprocessing.Queue) -> None:
    """Execute ``py_code`` in a fresh ``code.InteractiveConsole`` and report its output.

    Each top-level statement is run as a separate interactive input. The value
    of a bare expression statement is echoed to stdout. An exception in one
    statement is written to stderr as a traceback, and the statements after it
    still run.

    Args:
        py_code: Python source to execute.
        queue: Receives one dict ``{'stdout': str, 'stderr': str}`` holding
            everything written to ``sys.stdout`` / ``sys.stderr`` while the code
            ran. The dict is put even when execution is interrupted by an
            exception, so a caller blocked on ``queue.get()`` is not left hanging.
            The only exception is a process that is killed outright.
    """
    from contextlib import redirect_stderr, redirect_stdout

    console = code.InteractiveConsole()
    temp_stdout = StringIO()
    temp_stderr = StringIO()
    try:
        # The context managers restore the real streams even if execution raises,
        # so the hosting process never keeps writing into these buffers.
        with redirect_stdout(temp_stdout), redirect_stderr(temp_stderr):
            _run_statements_in_console(console, py_code)
    finally:
        queue.put({'stdout': temp_stdout.getvalue(), 'stderr': temp_stderr.getvalue()})
|
|
|
|
|
|
@tool
def run_python(py_code: str, reasoning: str, timeout: Optional[int] = None) -> str:
    """Run Python code in a simulated REPL environment and return anything printed.

    Timeout after the specified number of seconds.

    Args:
        py_code: The Python source to execute.
        reasoning: Explanation of why the code is being run. It does not affect execution.
        timeout: Optional limit in seconds. When given, the code runs in a separate
            process that is killed if it has not reported back in time. When omitted,
            the code runs in the current process with no limit.

    Returns:
        The captured output as a dict with ``'stdout'`` and ``'stderr'`` keys,
        or the string ``"Execution timed out"`` if the timeout expired first.
    """
    from queue import Empty

    if PRINT_USAGE:
        _print_func_call('python', {'code': py_code, 'reasoning': reasoning})

    result_queue = multiprocessing.Queue()

    # Only use a separate process if we are enforcing a timeout.
    if timeout is None:
        _run_py_worker(py_code, result_queue)
        # NOTE: this is the worker's dict, despite the ``-> str`` annotation.
        return result_queue.get()

    worker = multiprocessing.Process(target=_run_py_worker, args=(py_code, result_queue))
    worker.start()

    # Read the result BEFORE joining. A child that has put data on a
    # multiprocessing.Queue cannot exit until that data is consumed. Joining
    # first can therefore wait out the whole timeout on a large output and then
    # kill a child that had already finished.
    try:
        result = result_queue.get(timeout=timeout)
    except Empty:
        # No result in time: the code is still running, or the child died
        # without reporting.
        result = "Execution timed out"

    # Don't leave the child behind. If it reported, any leftover work (such as
    # threads the code started) is discarded; if it overran, it is stopped.
    # kill() cannot be ignored by signal handlers the executed code may install.
    # join() reaps the process so it does not linger as a zombie.
    if worker.is_alive():
        worker.kill()
    worker.join()
    return result
|