From f4848423b828f8b8ab512ef1697a17bc9f719f42 Mon Sep 17 00:00:00 2001 From: Cyberes Date: Sun, 4 Feb 2024 20:56:32 -0700 Subject: [PATCH] use chromadb instead of faiss --- pers/langchain/tools/browser.py | 2 +- pers/langchain/tools/document_manager.py | 12 ++++++------ requirements.txt | 2 +- run.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pers/langchain/tools/browser.py b/pers/langchain/tools/browser.py index 957ea1d..640ae65 100644 --- a/pers/langchain/tools/browser.py +++ b/pers/langchain/tools/browser.py @@ -32,7 +32,7 @@ def render_webpage(url: str): @tool('render_webpage') def render_webpage_tool(url: str, reasoning: str): - """Fetches the raw HTML of a webpage for use with the `retrieve_from_faiss` tool.""" + """Fetches the raw HTML of a webpage for use with the `retrieve_from_chroma` tool.""" if PRINT_USAGE: _print_func_call('render_webpage', {'url': url, 'reasoning': reasoning}) html_source = render_webpage(url) diff --git a/pers/langchain/tools/document_manager.py b/pers/langchain/tools/document_manager.py index c36a894..f217669 100644 --- a/pers/langchain/tools/document_manager.py +++ b/pers/langchain/tools/document_manager.py @@ -1,6 +1,6 @@ from langchain.chains import RetrievalQA from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain_community.vectorstores.faiss import FAISS +from langchain_community.vectorstores.chroma import Chroma from langchain_core.tools import tool from langchain_openai import OpenAIEmbeddings @@ -11,7 +11,7 @@ class DocumentManager: """ A class to manage loading large documents into the chain and giving the agent the ability to read it on subsequent loops. """ - index: FAISS + index: Chroma qa_chain: RetrievalQA def __init__(self): @@ -21,7 +21,7 @@ class DocumentManager: assert isinstance(data, str) text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) splits = text_splitter.split_text(data) - self.index = FAISS.from_texts(splits, self.embeddings) + self.index = Chroma.from_texts(splits, self.embeddings) def create_retrieval(self): self.qa_chain = RetrievalQA.from_chain_type( @@ -29,13 +29,13 @@ class DocumentManager: chain_type="map_reduce", retriever=self.index.as_retriever(), ) - return {'success': True, 'msg': 'Use the `retrieve_from_faiss` to parse the result'} + return {'success': True, 'msg': 'Use the `retrieve_from_chroma` to parse the result'} def retrieve(self, question: str): return self.qa_chain.invoke({"query": question}) @tool -def retrieve_from_faiss(question: str): - """Retrieve data from data embedded in FAISS""" +def retrieve_from_chroma(question: str): + """Retrieve data from data embedded in the Chroma DB.""" return GLOBALS.DocumentManager.retrieve(question) diff --git a/requirements.txt b/requirements.txt index a657953..f0e5252 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ termcolor==2.4.0 -faiss-cpu==1.7.4 serpapi==0.1.5 langchain==0.1.5 langchain-openai==0.0.5 @@ -15,3 +14,4 @@ async-timeout==4.0.3 pyyaml==6.0.1 py-cpuinfo==9.0.0 psutil==5.9.8 +chromadb==0.4.22 \ No newline at end of file diff --git a/run.py b/run.py index 4bad0d0..b209aa8 100755 --- a/run.py +++ b/run.py @@ -24,7 +24,7 @@ from pers.langchain.history import HistoryManager from pers.langchain.tools.agent_end_response import end_my_response from pers.langchain.tools.bash import run_bash from pers.langchain.tools.browser import render_webpage_tool -from pers.langchain.tools.document_manager import DocumentManager, retrieve_from_faiss +from pers.langchain.tools.document_manager import DocumentManager, retrieve_from_chroma from pers.langchain.tools.google import search_google, search_google_news, search_google_maps from pers.langchain.tools.python import run_python from pers.langchain.tools.terminate import terminate_chat @@ -84,7 +84,7 @@ def main(): ] ) - tools = [end_my_response, run_bash, run_python, search_google, search_google_maps, search_google_news, retrieve_from_faiss, read_webpage, render_webpage_tool, terminate_chat] + tools = [end_my_response, run_bash, run_python, search_google, search_google_maps, search_google_news, retrieve_from_chroma, read_webpage, render_webpage_tool, terminate_chat] llm_with_tools = pers.GLOBALS.OpenAI.bind(functions=[convert_to_openai_function(t) for t in tools]) agent = (