99 changes: 85 additions & 14 deletions src/vdf_io/export_vdf/weaviate_export.py
@@ -2,14 +2,18 @@

from tqdm import tqdm
import weaviate
import json

from vdf_io.export_vdf.vdb_export_cls import ExportVDB
from vdf_io.meta_types import NamespaceMeta
from vdf_io.names import DBNames
from vdf_io.util import set_arg_from_input, set_arg_from_password
from typing import Dict, List

# Set these environment variables
URL = os.getenv("YOUR_WCS_URL")
APIKEY = os.getenv("YOUR_WCS_API_KEY")
OPENAI_APIKEY = os.getenv("OPENAI_APIKEY")
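
A small review-style note, not from the original threads: os.getenv("YOUR_WCS_URL") looks up an environment variable literally named YOUR_WCS_URL, so the "YOUR_" placeholder from the docs appears to have leaked into code, and these module-level globals are not referenced anywhere in the hunks shown here (the class reads everything from self.args). If they are kept, a sketch with plainer names:

URL = os.getenv("WCS_URL")
APIKEY = os.getenv("WCS_API_KEY")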


class ExportWeaviate(ExportVDB):
@@ -23,6 +27,15 @@ def make_parser(cls, subparsers):

parser_weaviate.add_argument("--url", type=str, help="URL of Weaviate instance")
parser_weaviate.add_argument("--api_key", type=str, help="Weaviate API key")
parser_weaviate.add_argument("--openai_api_key", type=str, help="Openai API key")
parser_weaviate.add_arguments(
The method add_arguments does not exist in argparse, so this line will raise an AttributeError at runtime. It should be add_argument.

Suggested change
parser_weaviate.add_arguments(
parser_weaviate.add_argument(

"--batch_size", type=int, help="batch size for fetching",
default=1000
)
parser_weaviate.add_argument(
"--connection-type", type=str, choices=["local", "cloud"], default="cloud",
help="Type of connection to Weaviate (local or cloud)"
)
parser_weaviate.add_argument(
"--classes", type=str, help="Classes to export (comma-separated)"
)
@@ -35,6 +48,12 @@ def export_vdb(cls, args):
"Enter the URL of Weaviate instance: ",
str,
)
set_arg_from_input(
args,
"connection_type",
"Enter 'local' or 'cloud' for connection types: ",
choices=['local', 'cloud'],
)
set_arg_from_password(
args,
"api_key",
@@ -55,14 +74,20 @@ def export_vdb(cls, args):
weaviate_export.get_data()
return weaviate_export

# Connect to a WCS instance
# Connect to a WCS or local instance
def __init__(self, args):
super().__init__(args)
self.client = weaviate.connect_to_wcs(
cluster_url=self.args["url"],
auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
skip_init_checks=True,
)
if self.args["connection_type"] == "local":
self.client = weaviate.connect_to_local()
else:
self.client = weaviate.connect_to_wcs(
cluster_url=self.args["url"],
auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
headers={'X-OpenAI-Api-key': self.args["openai_api_key"]}
if self.args["openai_api_key"]
else None,
skip_init_checks=True,
)

def get_index_names(self):
if self.args.get("classes") is None:
@@ -76,14 +101,60 @@ def get_index_names(self):
return [c for c in self.all_classes if c in input_classes]

def get_data(self):
# Get all objects of a class
# Get the index names to export
index_names = self.get_index_names()
for class_name in index_names:
collection = self.client.collections.get(class_name)
index_metas: Dict[str, List[NamespaceMeta]] = {}

# Iterate over index names and fetch data
for index_name in index_names:
collection = self.client.collections.get(index_name)
response = collection.aggregate.over_all(total_count=True)
print(f"{response.total_count=}")
total_vector_count = response.total_count

# Create vectors directory for this index
vectors_directory = self.create_vec_dir(index_name)

# Export data in batches
batch_size = self.args["batch_size"]
num_batches = (total_vector_count + batch_size - 1) // batch_size
num_vectors_exported = 0

for batch_idx in tqdm(range(num_batches), desc=f"Exporting {index_name}"):
offset = batch_idx * batch_size
objects = collection.objects.limit(batch_size).offset(offset).get()
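
A review-style note beyond the two threads above: in the v4 Python client, paging is usually done with collection.query.fetch_objects rather than by chaining on collection.objects, and vectors only come back when explicitly requested. A sketch, assuming weaviate-client v4:

response = collection.query.fetch_objects(
    limit=batch_size,
    offset=offset,
    include_vector=True,  # vectors are omitted from results unless requested
)
objects = response.objects  # each object carries .uuid, .properties and, when requested, .vector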

# Extract vectors and metadata
vectors = {obj.id: obj.vector for obj in objects}
metadata = {}
# Need a better way
for obj in objects:
metadata[obj.id] = {attr: getattr(obj, attr) for attr in dir(obj) if not attr.startswith("__")}
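
On the "# Need a better way" note above: dir(obj) will sweep in methods and client internals along with the data fields. Assuming the v4 client, each returned object already exposes its data fields as a plain dict, so a sketch:

for obj in objects:
    # properties is a plain dict of the object's data fields (v4 client)
    metadata[obj.uuid] = dict(obj.properties)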


# Save vectors and metadata to Parquet file
num_vectors_exported += self.save_vectors_to_parquet(
vectors, metadata, vectors_directory
)

# Create NamespaceMeta for this index
namespace_metas = [
self.get_namespace_meta(
index_name,
vectors_directory,
total=total_vector_count,
num_vectors_exported=num_vectors_exported,
dim=300, # Not sure of the dimensions
distance="Cosine",
)
]
index_metas[index_name] = namespace_metas
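
On the dim=300 placeholder flagged in the inline comment above: the dimensionality can be inferred from the data rather than hard-coded, assuming at least one vector was exported. A sketch:

# Hypothetical: take the dimension from the last batch's vectors instead of guessing
dim = len(next(iter(vectors.values()))) if vectors else -1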

# Write VDFMeta to JSON file
self.file_structure.append(os.path.join(self.vdf_directory, "VDF_META.json"))
internal_metadata = self.get_basic_vdf_meta(index_metas)
meta_text = json.dumps(internal_metadata.model_dump(), indent=4)
tqdm.write(meta_text)

print("Data export complete.")

# objects = self.client.query.get(
# wvq.Objects(wvq.Class(class_name)).with_limit(1000)
# )
# print(objects)
return True
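
One last note on this file: the v4 client keeps connections open, so closing it once the export finishes avoids resource warnings. A sketch, assuming the v4 client:

self.client.close()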
139 changes: 139 additions & 0 deletions src/vdf_io/import_vdf/weaviate_import.py
@@ -0,0 +1,139 @@
import os
import weaviate
import json
from tqdm import tqdm
from vdf_io.import_vdf.vdf_import_cls import ImportVDB
from vdf_io.names import DBNames
from vdf_io.util import set_arg_from_input, set_arg_from_password

# Set these environment variables
URL = os.getenv("YOUR_WCS_URL")
APIKEY = os.getenv("YOUR_WCS_API_KEY")


class ImportWeaviate(ImportVDB):
DB_NAME_SLUG = DBNames.WEAVIATE

@classmethod
def make_parser(cls, subparsers):
parser_weaviate = subparsers.add_parser(
cls.DB_NAME_SLUG, help="Import data into Weaviate"
)

parser_weaviate.add_argument("--url", type=str, help="URL of Weaviate instance")
parser_weaviate.add_argument("--api_key", type=str, help="Weaviate API key")
parser_weaviate.add_argument(
"--index_name", type=str, help="Name of the index in Weaviate"
)

@classmethod
def import_vdb(cls, args):
set_arg_from_input(
args,
"url",
"Enter the URL of Weaviate instance: ",
str,
)
set_arg_from_password(
args,
"api_key",
"Enter the Weaviate API key: ",
"WEAVIATE_API_KEY",
)
set_arg_from_input(
args,
"index_name",
"Enter the name of the index in Weaviate: ",
str,
)
weaviate_import = ImportWeaviate(args)
weaviate_import.upsert_data()
return weaviate_import

def __init__(self, args):
The connection_type argument is used here but it is not defined in the argument parser for the import script. This will cause an error when trying to access self.args["connection_type"].

To fix this, add the connection_type argument to the parser in the make_parser method:

Suggested change
def __init__(self, args):
parser_weaviate.add_argument(
"--connection-type", type=str, choices=["local", "cloud"], default="cloud",
help="Type of connection to Weaviate (local or cloud)"
)
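
A related gap, as an additional note: __init__ below also reads self.args["openai_api_key"], which this parser never defines either. A sketch of the missing argument, mirroring the export parser:

parser_weaviate.add_argument(
    "--openai_api_key", type=str, help="OpenAI API key"
)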

super().__init__(args)
if self.args["connection_type"] == "local":
self.client = weaviate.connect_to_local()
else:
self.client = weaviate.connect_to_wcs(
cluster_url=self.args["url"],
auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
headers={'X-OpenAI-Api-key': self.args["openai_api_key"]}
if self.args["openai_api_key"]
else None,
skip_init_checks=True,
)

def upsert_data(self):
max_hit = False
total_imported_count = 0

# Iterate over the indexes and import the data
for index_name, index_meta in tqdm(self.vdf_meta["indexes"].items(), desc="Importing indexes"):
tqdm.write(f"Importing data for index '{index_name}'")
for namespace_meta in index_meta:
self.set_dims(namespace_meta, index_name)

# Create or get the index
index_name = self.create_new_name(index_name, self.client.collections.list_all().keys())
index = self.client.collections.get(index_name)

# Load data from the Parquet files
data_path = namespace_meta["data_path"]
final_data_path = self.get_final_data_path(data_path)
parquet_files = self.get_parquet_files(final_data_path)

vectors = {}
metadata = {}

# for file in tqdm(parquet_files, desc="Loading data from parquet files"):
# file_path = os.path.join(final_data_path, file)
# df = self.read_parquet_progress(file_path)

# if len(vectors) > (self.args.get("max_num_rows") or INT_MAX):
# max_hit = True
# break

# self.update_vectors(vectors, vector_column_name, df)
# self.update_metadata(metadata, vector_column_names, df)
# if max_hit:
# break

# tqdm.write(f"Loaded {len(vectors)} vectors from {len(parquet_files)} parquet files")

# # Upsert the vectors and metadata to the Weaviate index in batches
# BATCH_SIZE = self.args.get("batch_size", 1000) or 1000
# current_batch_size = BATCH_SIZE
# start_idx = 0

# while start_idx < len(vectors):
# end_idx = min(start_idx + current_batch_size, len(vectors))

# batch_vectors = [
# (
# str(id),
# vector,
# {
# k: v
# for k, v in metadata.get(id, {}).items()
# if v is not None
# } if len(metadata.get(id, {}).keys()) > 0 else None
# )
# for id, vector in list(vectors.items())[start_idx:end_idx]
# ]

# try:
# resp = index.batch.create(batch_vectors)
# total_imported_count += len(batch_vectors)
# start_idx += len(batch_vectors)
# except Exception as e:
# tqdm.write(f"Error upserting vectors for index '{index_name}', {e}")
# if current_batch_size < BATCH_SIZE / 100:
# tqdm.write("Batch size is not the issue. Aborting import")
# raise e
# current_batch_size = int(2 * current_batch_size / 3)
# tqdm.write(f"Reducing batch size to {current_batch_size}")
# continue

# tqdm.write(f"Data import completed successfully. Imported {total_imported_count} vectors")
# self.args["imported_count"] = total_imported_count
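
A closing note on the commented-out block: index.batch.create is not a v4 client method, so when this logic is revived it will need the v4 batching API. A minimal sketch, assuming weaviate-client v4 and the vectors/metadata dicts built above:

with index.batch.dynamic() as batch:
    for id, vector in vectors.items():
        props = {k: v for k, v in metadata.get(id, {}).items() if v is not None}
        # add_object queues the object; the context manager flushes batches automatically
        batch.add_object(properties=props, uuid=str(id), vector=vector)
total_imported_count += len(vectors)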