use chromadb instead of faiss
This commit is contained in:
parent
26a320dd2d
commit
f4848423b8
|
@ -32,7 +32,7 @@ def render_webpage(url: str):
|
||||||
|
|
||||||
@tool('render_webpage')
|
@tool('render_webpage')
|
||||||
def render_webpage_tool(url: str, reasoning: str):
|
def render_webpage_tool(url: str, reasoning: str):
|
||||||
"""Fetches the raw HTML of a webpage for use with the `retrieve_from_faiss` tool."""
|
"""Fetches the raw HTML of a webpage for use with the `retrieve_from_chroma` tool."""
|
||||||
if PRINT_USAGE:
|
if PRINT_USAGE:
|
||||||
_print_func_call('render_webpage', {'url': url, 'reasoning': reasoning})
|
_print_func_call('render_webpage', {'url': url, 'reasoning': reasoning})
|
||||||
html_source = render_webpage(url)
|
html_source = render_webpage(url)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from langchain.chains import RetrievalQA
|
from langchain.chains import RetrievalQA
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
from langchain_community.vectorstores.faiss import FAISS
|
from langchain_community.vectorstores.chroma import Chroma
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ class DocumentManager:
|
||||||
"""
|
"""
|
||||||
A class to manage loading large documents into the chain and giving the agent the ability to read it on subsequent loops.
|
A class to manage loading large documents into the chain and giving the agent the ability to read it on subsequent loops.
|
||||||
"""
|
"""
|
||||||
index: FAISS
|
index: Chroma
|
||||||
qa_chain: RetrievalQA
|
qa_chain: RetrievalQA
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -21,7 +21,7 @@ class DocumentManager:
|
||||||
assert isinstance(data, str)
|
assert isinstance(data, str)
|
||||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
|
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
|
||||||
splits = text_splitter.split_text(data)
|
splits = text_splitter.split_text(data)
|
||||||
self.index = FAISS.from_texts(splits, self.embeddings)
|
self.index = Chroma.from_texts(splits, self.embeddings)
|
||||||
|
|
||||||
def create_retrieval(self):
|
def create_retrieval(self):
|
||||||
self.qa_chain = RetrievalQA.from_chain_type(
|
self.qa_chain = RetrievalQA.from_chain_type(
|
||||||
|
@ -29,13 +29,13 @@ class DocumentManager:
|
||||||
chain_type="map_reduce",
|
chain_type="map_reduce",
|
||||||
retriever=self.index.as_retriever(),
|
retriever=self.index.as_retriever(),
|
||||||
)
|
)
|
||||||
return {'success': True, 'msg': 'Use the `retrieve_from_faiss` to parse the result'}
|
return {'success': True, 'msg': 'Use the `retrieve_from_chroma` to parse the result'}
|
||||||
|
|
||||||
def retrieve(self, question: str):
|
def retrieve(self, question: str):
|
||||||
return self.qa_chain.invoke({"query": question})
|
return self.qa_chain.invoke({"query": question})
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def retrieve_from_faiss(question: str):
|
def retrieve_from_chroma(question: str):
|
||||||
"""Retrieve data from data embedded in FAISS"""
|
"""Retrieve data from data embedded in the Chroma DB."""
|
||||||
return GLOBALS.DocumentManager.retrieve(question)
|
return GLOBALS.DocumentManager.retrieve(question)
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
termcolor==2.4.0
|
termcolor==2.4.0
|
||||||
faiss-cpu==1.7.4
|
|
||||||
serpapi==0.1.5
|
serpapi==0.1.5
|
||||||
langchain==0.1.5
|
langchain==0.1.5
|
||||||
langchain-openai==0.0.5
|
langchain-openai==0.0.5
|
||||||
|
@ -15,3 +14,4 @@ async-timeout==4.0.3
|
||||||
pyyaml==6.0.1
|
pyyaml==6.0.1
|
||||||
py-cpuinfo==9.0.0
|
py-cpuinfo==9.0.0
|
||||||
psutil==5.9.8
|
psutil==5.9.8
|
||||||
|
chromadb==0.4.22
|
4
run.py
4
run.py
|
@ -24,7 +24,7 @@ from pers.langchain.history import HistoryManager
|
||||||
from pers.langchain.tools.agent_end_response import end_my_response
|
from pers.langchain.tools.agent_end_response import end_my_response
|
||||||
from pers.langchain.tools.bash import run_bash
|
from pers.langchain.tools.bash import run_bash
|
||||||
from pers.langchain.tools.browser import render_webpage_tool
|
from pers.langchain.tools.browser import render_webpage_tool
|
||||||
from pers.langchain.tools.document_manager import DocumentManager, retrieve_from_faiss
|
from pers.langchain.tools.document_manager import DocumentManager, retrieve_from_chroma
|
||||||
from pers.langchain.tools.google import search_google, search_google_news, search_google_maps
|
from pers.langchain.tools.google import search_google, search_google_news, search_google_maps
|
||||||
from pers.langchain.tools.python import run_python
|
from pers.langchain.tools.python import run_python
|
||||||
from pers.langchain.tools.terminate import terminate_chat
|
from pers.langchain.tools.terminate import terminate_chat
|
||||||
|
@ -84,7 +84,7 @@ def main():
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
tools = [end_my_response, run_bash, run_python, search_google, search_google_maps, search_google_news, retrieve_from_faiss, read_webpage, render_webpage_tool, terminate_chat]
|
tools = [end_my_response, run_bash, run_python, search_google, search_google_maps, search_google_news, retrieve_from_chroma, read_webpage, render_webpage_tool, terminate_chat]
|
||||||
llm_with_tools = pers.GLOBALS.OpenAI.bind(functions=[convert_to_openai_function(t) for t in tools])
|
llm_with_tools = pers.GLOBALS.OpenAI.bind(functions=[convert_to_openai_function(t) for t in tools])
|
||||||
|
|
||||||
agent = (
|
agent = (
|
||||||
|
|
Reference in New Issue