"""Document-retrieval helpers: embed large documents into FAISS and expose a retrieval tool for the agent."""
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores.faiss import FAISS
from langchain_core.tools import tool
from langchain_openai import OpenAIEmbeddings

from pers import GLOBALS
class DocumentManager:
    """
    A class to manage loading large documents into the chain and giving the agent the ability to read it on subsequent loops.

    Intended call order:
        1. ``load_data`` — split and embed a raw text blob into an in-memory FAISS index.
        2. ``create_retrieval`` — build a RetrievalQA chain over that index.
        3. ``retrieve`` — answer questions against the indexed document.
    """

    # Set by load_data(): FAISS vector index over the document chunks.
    # (Annotations quoted so they are not evaluated at class-creation time.)
    index: "FAISS"
    # Set by create_retrieval(): QA chain backed by `index`.
    qa_chain: "RetrievalQA"

    def __init__(self):
        # Embedding model used when building the FAISS index; keyed from
        # the project-wide GLOBALS config.
        self.embeddings = OpenAIEmbeddings(openai_api_key=GLOBALS.OPENAI_KEY)

    def load_data(self, data: str) -> None:
        """Split ``data`` into overlapping chunks and embed them into a FAISS index.

        Args:
            data: Raw document text to index.

        Raises:
            TypeError: If ``data`` is not a string.
        """
        # Raise instead of `assert`: asserts are stripped under `python -O`,
        # which would silently let non-str input through.
        if not isinstance(data, str):
            raise TypeError(f"data must be a str, got {type(data).__name__}")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_text(data)
        self.index = FAISS.from_texts(splits, self.embeddings)

    def create_retrieval(self):
        """Build the RetrievalQA chain over the previously loaded index.

        Returns:
            A status dict instructing the agent which tool to call next.

        Raises:
            RuntimeError: If ``load_data`` has not been called yet.
        """
        # Guard against out-of-order use: without this, the bare attribute
        # access below fails with an unhelpful AttributeError.
        if not hasattr(self, "index"):
            raise RuntimeError("No document loaded: call load_data() before create_retrieval().")
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=GLOBALS.OpenAI,
            chain_type="map_reduce",
            retriever=self.index.as_retriever(),
        )
        return {'success': True, 'msg': 'Use the `retrieve_from_faiss` to parse the result'}

    def retrieve(self, question: str):
        """Answer ``question`` using the RetrievalQA chain.

        Args:
            question: Natural-language query over the loaded document.

        Returns:
            The chain's invoke() result (a dict containing the answer).

        Raises:
            RuntimeError: If ``create_retrieval`` has not been called yet.
        """
        if not hasattr(self, "qa_chain"):
            raise RuntimeError("No retrieval chain: call create_retrieval() before retrieve().")
        return self.qa_chain.invoke({"query": question})
|
|
|
|
|
|
@tool
def retrieve_from_faiss(question: str):
    """Retrieve data from data embedded in FAISS"""
    # NOTE(review): the docstring above doubles as the tool description the
    # LLM sees (via @tool), so it is runtime behavior — left untouched.
    # Delegates to the shared DocumentManager held on GLOBALS; presumably
    # populated at startup — verify against `pers.GLOBALS`.
    return GLOBALS.DocumentManager.retrieve(question)
|