Skip to content
This repository has been archived by the owner on Nov 9, 2024. It is now read-only.

Commit

Permalink
Add index name parameter to agent (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
umutunsal authored Oct 8, 2024
1 parent 25dbd18 commit 2e08d7e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
17 changes: 14 additions & 3 deletions hive_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
retrieve=False,
required_exts=supported_exts,
retrieval_tool="basic",
index_name : Optional[str] = None,
load_index_file=False,
swarm_mode=False,
chat_only_mode=False,
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(
self.retrieve = retrieve
self.required_exts = required_exts
self.retrieval_tool = retrieval_tool
self.index_name = index_name
self.load_index_file = load_index_file
logging.basicConfig(stream=sys.stdout, level=self.__config.get("log"))
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
Expand Down Expand Up @@ -277,17 +279,26 @@ def add_batch_indexes(self):

if "chroma" in self.retrieval_tool:
chroma_retriever = ChromaRetriever()
index, file_names = chroma_retriever.create_index()
if self.index_name is not None:
index, file_names = chroma_retriever.create_index(collection_name=self.index_name)
else:
index, file_names = chroma_retriever.create_index()
self.index_store.add_index(chroma_retriever.name, index, file_names)

if "pinecone-serverless" in self.retrieval_tool:
pinecone_retriever = PineconeRetriever()
index, file_names = pinecone_retriever.create_serverless_index()
if self.index_name is not None:
index, file_names = pinecone_retriever.create_serverless_index(collection_name=self.index_name)
else:
index, file_names = pinecone_retriever.create_serverless_index()
self.index_store.add_index(pinecone_retriever.name, index, file_names)

if "pinecone-pod" in self.retrieval_tool:
pinecone_retriever = PineconeRetriever()
index, file_names = pinecone_retriever.create_pod_index()
if self.index_name is not None:
index, file_names = pinecone_retriever.create_pod_index(collection_name=self.index_name)
else:
index, file_names = pinecone_retriever.create_pod_index()
self.index_store.add_index(pinecone_retriever.name, index, file_names)

self.index_store.save_to_file()
Expand Down
12 changes: 6 additions & 6 deletions hive_agent/tools/retriever/pinecone_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def create_serverless_index(
folder_path=None,
prefix='',
bucket=None,
name="hive-agent-pinecone",
collection_name="hive-agent-pinecone",
dimension=1536,
metric="euclidean",
cloud="aws",
Expand All @@ -45,12 +45,12 @@ def create_serverless_index(
else:
documents, file_names = self._load_documents(file_path, folder_path)
self.pinecone_client.create_index(
name=name,
name=collection_name,
dimension=dimension,
metric=metric,
spec=ServerlessSpec(cloud=cloud, region=region),
)
pinecone_index = self.pinecone_client.Index(name)
pinecone_index = self.pinecone_client.Index(collection_name)

vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
Expand All @@ -65,7 +65,7 @@ def create_pod_index(
folder_path=None,
s3_endpoint_url=None,
bucket=None,
name="hive-agent-pinecone-pod",
collection_name="hive-agent-pinecone-pod",
dimension=1536,
metric="cosine",
environment="us-east1-gcp",
Expand All @@ -77,12 +77,12 @@ def create_pod_index(
else:
documents, file_names = self._load_documents(file_path, folder_path)
self.pinecone_client.create_index(
name=name,
name=collection_name,
dimension=dimension,
metric=metric,
spec=PodSpec(environment=environment, pod_type=pod_type, pods=pods),
)
pinecone_index = self.pinecone_client.Index(name)
pinecone_index = self.pinecone_client.Index(collection_name)

vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
Expand Down

0 comments on commit 2e08d7e

Please sign in to comment.