use chromadb instead of faiss

This commit is contained in:
Cyberes 2024-02-04 20:56:32 -07:00
parent 26a320dd2d
commit f4848423b8
4 changed files with 10 additions and 10 deletions

View File

@ -32,7 +32,7 @@ def render_webpage(url: str):
@tool('render_webpage') @tool('render_webpage')
def render_webpage_tool(url: str, reasoning: str): def render_webpage_tool(url: str, reasoning: str):
"""Fetches the raw HTML of a webpage for use with the `retrieve_from_faiss` tool.""" """Fetches the raw HTML of a webpage for use with the `retrieve_from_chroma` tool."""
if PRINT_USAGE: if PRINT_USAGE:
_print_func_call('render_webpage', {'url': url, 'reasoning': reasoning}) _print_func_call('render_webpage', {'url': url, 'reasoning': reasoning})
html_source = render_webpage(url) html_source = render_webpage(url)

View File

@ -1,6 +1,6 @@
from langchain.chains import RetrievalQA from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores.faiss import FAISS from langchain_community.vectorstores.chroma import Chroma
from langchain_core.tools import tool from langchain_core.tools import tool
from langchain_openai import OpenAIEmbeddings from langchain_openai import OpenAIEmbeddings
@ -11,7 +11,7 @@ class DocumentManager:
""" """
A class to manage loading large documents into the chain and giving the agent the ability to read it on subsequent loops. A class to manage loading large documents into the chain and giving the agent the ability to read it on subsequent loops.
""" """
index: FAISS index: Chroma
qa_chain: RetrievalQA qa_chain: RetrievalQA
def __init__(self): def __init__(self):
@ -21,7 +21,7 @@ class DocumentManager:
assert isinstance(data, str) assert isinstance(data, str)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_text(data) splits = text_splitter.split_text(data)
self.index = FAISS.from_texts(splits, self.embeddings) self.index = Chroma.from_texts(splits, self.embeddings)
def create_retrieval(self): def create_retrieval(self):
self.qa_chain = RetrievalQA.from_chain_type( self.qa_chain = RetrievalQA.from_chain_type(
@ -29,13 +29,13 @@ class DocumentManager:
chain_type="map_reduce", chain_type="map_reduce",
retriever=self.index.as_retriever(), retriever=self.index.as_retriever(),
) )
return {'success': True, 'msg': 'Use the `retrieve_from_faiss` to parse the result'} return {'success': True, 'msg': 'Use the `retrieve_from_chroma` to parse the result'}
def retrieve(self, question: str): def retrieve(self, question: str):
return self.qa_chain.invoke({"query": question}) return self.qa_chain.invoke({"query": question})
@tool @tool
def retrieve_from_faiss(question: str): def retrieve_from_chroma(question: str):
"""Retrieve data from data embedded in FAISS""" """Retrieve data from data embedded in the Chroma DB."""
return GLOBALS.DocumentManager.retrieve(question) return GLOBALS.DocumentManager.retrieve(question)

View File

@ -1,5 +1,4 @@
termcolor==2.4.0 termcolor==2.4.0
faiss-cpu==1.7.4
serpapi==0.1.5 serpapi==0.1.5
langchain==0.1.5 langchain==0.1.5
langchain-openai==0.0.5 langchain-openai==0.0.5
@ -15,3 +14,4 @@ async-timeout==4.0.3
pyyaml==6.0.1 pyyaml==6.0.1
py-cpuinfo==9.0.0 py-cpuinfo==9.0.0
psutil==5.9.8 psutil==5.9.8
chromadb==0.4.22

4
run.py
View File

@ -24,7 +24,7 @@ from pers.langchain.history import HistoryManager
from pers.langchain.tools.agent_end_response import end_my_response from pers.langchain.tools.agent_end_response import end_my_response
from pers.langchain.tools.bash import run_bash from pers.langchain.tools.bash import run_bash
from pers.langchain.tools.browser import render_webpage_tool from pers.langchain.tools.browser import render_webpage_tool
from pers.langchain.tools.document_manager import DocumentManager, retrieve_from_faiss from pers.langchain.tools.document_manager import DocumentManager, retrieve_from_chroma
from pers.langchain.tools.google import search_google, search_google_news, search_google_maps from pers.langchain.tools.google import search_google, search_google_news, search_google_maps
from pers.langchain.tools.python import run_python from pers.langchain.tools.python import run_python
from pers.langchain.tools.terminate import terminate_chat from pers.langchain.tools.terminate import terminate_chat
@ -84,7 +84,7 @@ def main():
] ]
) )
tools = [end_my_response, run_bash, run_python, search_google, search_google_maps, search_google_news, retrieve_from_faiss, read_webpage, render_webpage_tool, terminate_chat] tools = [end_my_response, run_bash, run_python, search_google, search_google_maps, search_google_news, retrieve_from_chroma, read_webpage, render_webpage_tool, terminate_chat]
llm_with_tools = pers.GLOBALS.OpenAI.bind(functions=[convert_to_openai_function(t) for t in tools]) llm_with_tools = pers.GLOBALS.OpenAI.bind(functions=[convert_to_openai_function(t) for t in tools])
agent = ( agent = (