# server-personification/pers/langchain/tools/document_manager.py
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores.faiss import FAISS
from langchain_core.tools import tool
from langchain_openai import OpenAIEmbeddings
from pers import GLOBALS
class DocumentManager:
    """
    Manage loading large documents into a FAISS vector index so the agent
    can query them on subsequent loops instead of holding the full text
    in context.
    """

    # FAISS vector index built from the most recently loaded document.
    index: FAISS
    # RetrievalQA chain over ``index``; set by ``create_retrieval``.
    qa_chain: RetrievalQA

    def __init__(self):
        # Embedding model used to vectorize document chunks at load time.
        self.embeddings = OpenAIEmbeddings(openai_api_key=GLOBALS.OPENAI_KEY)

    def load_data(self, data: str) -> None:
        """Split ``data`` into overlapping chunks and embed them into FAISS.

        Args:
            data: The raw document text to index.

        Raises:
            TypeError: If ``data`` is not a string.
        """
        # `assert` is stripped under `python -O`; validate explicitly instead.
        if not isinstance(data, str):
            raise TypeError(f"data must be str, got {type(data).__name__}")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_text(data)
        self.index = FAISS.from_texts(splits, self.embeddings)

    def create_retrieval(self):
        """Build a map-reduce RetrievalQA chain over the loaded index.

        Must be called after ``load_data`` (it reads ``self.index``).

        Returns:
            dict: A success flag and a hint directing the agent to the
            ``retrieve_from_faiss`` tool.
        """
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=GLOBALS.OpenAI,
            chain_type="map_reduce",
            retriever=self.index.as_retriever(),
        )
        return {'success': True, 'msg': 'Use the `retrieve_from_faiss` to parse the result'}

    def retrieve(self, question: str):
        """Answer ``question`` against the indexed document via the QA chain."""
        return self.qa_chain.invoke({"query": question})
@tool
def retrieve_from_faiss(question: str):
    """Retrieve data from data embedded in FAISS"""
    # Delegate to the globally-registered manager's QA chain.
    manager = GLOBALS.DocumentManager
    return manager.retrieve(question)