diff --git a/src/vdf_io/export_vdf/weaviate_export.py b/src/vdf_io/export_vdf/weaviate_export.py
index 518dd3c..b71d10a 100644
--- a/src/vdf_io/export_vdf/weaviate_export.py
+++ b/src/vdf_io/export_vdf/weaviate_export.py
@@ -1,15 +1,17 @@
 import os
-
-from tqdm import tqdm
 import weaviate
+import json
+from typing import Dict, List
+from tqdm import tqdm
+
+from weaviate.classes.query import MetadataQuery

 from vdf_io.export_vdf.vdb_export_cls import ExportVDB
+from vdf_io.meta_types import NamespaceMeta
 from vdf_io.names import DBNames
-from vdf_io.util import set_arg_from_input, set_arg_from_password
-
-# Set these environment variables
-URL = os.getenv("YOUR_WCS_URL")
-APIKEY = os.getenv("YOUR_WCS_API_KEY")
+from vdf_io.util import set_arg_from_input
+from vdf_io.constants import DEFAULT_BATCH_SIZE
+from vdf_io.weaviate_util import prompt_for_creds


 class ExportWeaviate(ExportVDB):
@@ -23,24 +25,32 @@ def make_parser(cls, subparsers):
         parser_weaviate.add_argument("--url", type=str, help="URL of Weaviate instance")
         parser_weaviate.add_argument("--api_key", type=str, help="Weaviate API key")
+        parser_weaviate.add_argument(
+            "--openai_api_key", type=str, help="OpenAI API key"
+        )
+        parser_weaviate.add_argument(
+            "--batch_size",
+            type=int,
+            help="Batch size for fetching",
+            default=DEFAULT_BATCH_SIZE,
+        )
+        parser_weaviate.add_argument(
+            "--offset", type=int, help="Offset for fetching", default=None
+        )
+        parser_weaviate.add_argument(
+            "--connection-type",
+            type=str,
+            choices=["local", "cloud"],
+            default="cloud",
+            help="Type of connection to Weaviate (local or cloud)",
+        )
         parser_weaviate.add_argument(
             "--classes", type=str, help="Classes to export (comma-separated)"
         )

     @classmethod
     def export_vdb(cls, args):
-        set_arg_from_input(
-            args,
-            "url",
-            "Enter the URL of Weaviate instance: ",
-            str,
-        )
-        set_arg_from_password(
-            args,
-            "api_key",
-            "Enter the Weaviate API key: ",
-            "WEAVIATE_API_KEY",
-        )
+        prompt_for_creds(args)
         weaviate_export = ExportWeaviate(args)
         weaviate_export.all_classes = list(
             weaviate_export.client.collections.list_all().keys()
@@ -55,14 +65,20 @@ def export_vdb(cls, args):
         weaviate_export.get_data()
         return weaviate_export

-    # Connect to a WCS instance
+    # Connect to a WCS or local instance
     def __init__(self, args):
         super().__init__(args)
-        self.client = weaviate.connect_to_wcs(
-            cluster_url=self.args["url"],
-            auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
-            skip_init_checks=True,
-        )
+        if self.args["connection_type"] == "local":
+            self.client = weaviate.connect_to_local()
+        else:
+            self.client = weaviate.connect_to_wcs(
+                cluster_url=self.args["url"],
+                auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
+                headers={"X-OpenAI-Api-key": self.args["openai_api_key"]}
+                if self.args["openai_api_key"]
+                else None,
+                skip_init_checks=True,
+            )

     def get_index_names(self):
         if self.args.get("classes") is None:
@@ -75,15 +91,73 @@ def get_index_names(self):
         )
         return [c for c in self.all_classes if c in input_classes]

+    def metadata_to_dict(self, metadata):
+        meta_data = {}
+        meta_data["creation_time"] = metadata.creation_time
+        meta_data["distance"] = metadata.distance
+        meta_data["certainty"] = metadata.certainty
+        meta_data["explain_score"] = metadata.explain_score
+        meta_data["is_consistent"] = metadata.is_consistent
+        meta_data["last_update_time"] = metadata.last_update_time
+        meta_data["rerank_score"] = metadata.rerank_score
+        meta_data["score"] = metadata.score
+
+        return meta_data
+
     def get_data(self):
-        # Get all objects of a class
+        # Get the index names to export
         index_names = self.get_index_names()
-        for class_name in index_names:
-            collection = self.client.collections.get(class_name)
-            response = collection.aggregate.over_all(total_count=True)
-            print(f"{response.total_count=}")
-
-        # objects = self.client.query.get(
-        #     wvq.Objects(wvq.Class(class_name)).with_limit(1000)
-        # )
-        # print(objects)
+        index_metas: Dict[str, List[NamespaceMeta]] = {}
+
+        # Export data in batches
+        batch_size = self.args["batch_size"]
+        offset = self.args["offset"]
+
+        # Iterate over index names and fetch data
+        for index_name in index_names:
+            collection = self.client.collections.get(index_name)
+            response = collection.query.fetch_objects(
+                limit=batch_size,
+                offset=offset,
+                include_vector=True,
+                return_metadata=MetadataQuery.full(),
+            )
+            res = collection.aggregate.over_all(total_count=True)
+            total_vector_count = res.total_count
+
+            # Create vectors directory for this index
+            vectors_directory = self.create_vec_dir(index_name)
+
+            for obj in response.objects:
+                vectors = obj.vector
+                metadata = obj.metadata
+                metadata = self.metadata_to_dict(metadata=metadata)
+
+                # Save vectors and metadata to Parquet file
+                num_vectors_exported = self.save_vectors_to_parquet(
+                    vectors, metadata, vectors_directory
+                )
+
+            # Create NamespaceMeta for this index
+            namespace_metas = [
+                self.get_namespace_meta(
+                    index_name,
+                    vectors_directory,
+                    total=total_vector_count,
+                    num_vectors_exported=num_vectors_exported,
+                    dim=-1,
+                    distance="Cosine",
+                )
+            ]
+            index_metas[index_name] = namespace_metas
+
+        # Write VDFMeta to JSON file
+        self.file_structure.append(os.path.join(self.vdf_directory, "VDF_META.json"))
+        internal_metadata = self.get_basic_vdf_meta(index_metas)
+        meta_text = json.dumps(internal_metadata.model_dump(), indent=4)
+        tqdm.write(meta_text)
+        with open(os.path.join(self.vdf_directory, "VDF_META.json"), "w") as json_file:
+            json_file.write(meta_text)
+        print("Data export complete.")
+        return True
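Note on the export path: get_data fetches a single page per collection (one fetch_objects call bounded by --batch_size/--offset), so only that page is written out. For collections larger than one page, the v4 Python client also offers a cursor-based iterator; a minimal sketch of a full-collection walk, not part of this PR (the collection name and process() helper are placeholders):

    # Sketch: walking an entire collection with the weaviate-client v4 cursor API.
    collection = client.collections.get("MyClass")  # "MyClass" is hypothetical
    for obj in collection.iterator(include_vector=True):
        # Each returned object exposes its uuid, properties, and vector(s).
        process(obj.uuid, obj.properties, obj.vector)  # process() is a placeholder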
diff --git a/src/vdf_io/import_vdf/weaviate_import.py b/src/vdf_io/import_vdf/weaviate_import.py
new file mode 100644
index 0000000..f28befc
--- /dev/null
+++ b/src/vdf_io/import_vdf/weaviate_import.py
@@ -0,0 +1,123 @@
+import os
+import weaviate
+from tqdm import tqdm
+from vdf_io.import_vdf.vdf_import_cls import ImportVDB
+from vdf_io.names import DBNames
+from vdf_io.constants import INT_MAX, DEFAULT_BATCH_SIZE
+from vdf_io.weaviate_util import prompt_for_creds
+
+
+class ImportWeaviate(ImportVDB):
+    DB_NAME_SLUG = DBNames.WEAVIATE
+
+    @classmethod
+    def make_parser(cls, subparsers):
+        parser_weaviate = subparsers.add_parser(
+            cls.DB_NAME_SLUG, help="Import data into Weaviate"
+        )
+
+        parser_weaviate.add_argument("--url", type=str, help="URL of Weaviate instance")
+        parser_weaviate.add_argument("--api_key", type=str, help="Weaviate API key")
+        parser_weaviate.add_argument(
+            "--openai_api_key", type=str, help="OpenAI API key"
+        )
+        parser_weaviate.add_argument(
+            "--connection-type",
+            type=str,
+            choices=["local", "cloud"],
+            default="cloud",
+            help="Type of connection to Weaviate (local or cloud)",
+        )
+        parser_weaviate.add_argument(
+            "--batch_size",
+            type=int,
+            help="Batch size for importing",
+            default=DEFAULT_BATCH_SIZE,
+        )
+
+    @classmethod
+    def import_vdb(cls, args):
+        prompt_for_creds(args)
+        weaviate_import = ImportWeaviate(args)
+        weaviate_import.upsert_data()
+        return weaviate_import
+
+    def __init__(self, args):
+        super().__init__(args)
+        if self.args["connection_type"] == "local":
+            self.client = weaviate.connect_to_local()
+        else:
+            self.client = weaviate.connect_to_wcs(
+                cluster_url=self.args["url"],
+                auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
+                headers={"X-OpenAI-Api-key": self.args.get("openai_api_key")}
+                if self.args.get("openai_api_key")
+                else None,
+                skip_init_checks=True,
+            )
+
+    def upsert_data(self):
+        max_hit = False
+        total_imported_count = 0
+
+        # Iterate over the indexes and import the data
+        for index_name, index_meta in tqdm(
+            self.vdf_meta["indexes"].items(), desc="Importing indexes"
+        ):
+            tqdm.write(f"Importing data for index '{index_name}'")
+            for namespace_meta in index_meta:
+                self.set_dims(namespace_meta, index_name)
+
+                # Create or get the index
+                index_name = self.create_new_name(
+                    index_name, self.client.collections.list_all().keys()
+                )
+
+                # Load data from the Parquet files
+                data_path = namespace_meta["data_path"]
+                final_data_path = self.get_final_data_path(data_path)
+                parquet_files = self.get_parquet_files(final_data_path)
+
+                vectors = {}
+                metadata = {}
+                vector_column_names, vector_column_name = self.get_vector_column_name(
+                    index_name, namespace_meta
+                )
+
+                for file in tqdm(parquet_files, desc="Loading data from parquet files"):
+                    file_path = os.path.join(final_data_path, file)
+                    df = self.read_parquet_progress(file_path)
+
+                    if len(vectors) > (self.args.get("max_num_rows") or INT_MAX):
+                        max_hit = True
+                        break
+                    if len(vectors) + len(df) > (self.args.get("max_num_rows") or INT_MAX):
+                        df = df.head(
+                            (self.args.get("max_num_rows") or INT_MAX) - len(vectors)
+                        )
+                        max_hit = True
+                    self.update_vectors(vectors, vector_column_name, df)
+                    self.update_metadata(metadata, vector_column_names, df)
+                    if max_hit:
+                        break
+
+                tqdm.write(
+                    f"Loaded {len(vectors)} vectors from {len(parquet_files)} parquet files"
+                )
+
+                # Upsert the vectors and metadata to the Weaviate index in batches
+                BATCH_SIZE = self.args.get("batch_size")
+
+                with self.client.batch.fixed_size(batch_size=BATCH_SIZE) as batch:
+                    for _, vector in vectors.items():
+                        batch.add_object(
+                            vector=vector,
+                            collection=index_name,
+                            # TODO: Find way to add Metadata
+                        )
+                        total_imported_count += 1
+
+        tqdm.write(
+            f"Data import completed successfully. Imported {total_imported_count} vectors"
+        )
+        self.args["imported_count"] = total_imported_count
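The batch loop above leaves metadata attachment as a TODO. One possible wiring, assuming the metadata dict is keyed by the same ids as vectors and that its keys are valid Weaviate property names (both are assumptions, not guaranteed by this PR):

    # Sketch: attaching the stored metadata as object properties during import.
    with self.client.batch.fixed_size(batch_size=BATCH_SIZE) as batch:
        for obj_id, vector in vectors.items():
            batch.add_object(
                collection=index_name,
                vector=vector,
                properties=metadata.get(obj_id, {}),  # assumes metadata keyed by obj_id
            )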
diff --git a/src/vdf_io/weaviate_util.py b/src/vdf_io/weaviate_util.py
new file mode 100644
index 0000000..827b324
--- /dev/null
+++ b/src/vdf_io/weaviate_util.py
@@ -0,0 +1,24 @@
+from vdf_io.util import set_arg_from_input, set_arg_from_password
+
+
+def prompt_for_creds(args):
+    set_arg_from_input(
+        args,
+        "connection_type",
+        "Enter 'local' or 'cloud' for connection type: ",
+        choices=["local", "cloud"],
+    )
+    if args["connection_type"] == "cloud":
+        set_arg_from_input(
+            args,
+            "url",
+            "Enter the URL of Weaviate instance: ",
+            str,
+            env_var="WEAVIATE_URL",
+        )
+        set_arg_from_password(
+            args,
+            "api_key",
+            "Enter the Weaviate API key: ",
+            "WEAVIATE_API_KEY",
+        )
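For reference, prompt_for_creds mutates the parsed-args dict in place: with connection_type set to "local" nothing further is collected, while "cloud" prompts for the instance URL and API key (falling back to the WEAVIATE_URL and WEAVIATE_API_KEY environment variables). A minimal sketch of the expected flow, with illustrative values:

    # Sketch: prompt_for_creds fills in cloud credentials before connecting.
    args = {"connection_type": "cloud", "url": None, "api_key": None}
    prompt_for_creds(args)
    # After prompting (or env-var fallback), args["url"] and args["api_key"]
    # are populated, ready to hand to weaviate.connect_to_wcs().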