22
33from tqdm import tqdm
44import weaviate
5+ import json
56
67from vdf_io .export_vdf .vdb_export_cls import ExportVDB
8+ from vdf_io .meta_types import NamespaceMeta
79from vdf_io .names import DBNames
810from vdf_io .util import set_arg_from_input , set_arg_from_password
11+ from typing import Dict , List
912
1013# Set these environment variables
1114URL = os .getenv ("YOUR_WCS_URL" )
1215APIKEY = os .getenv ("YOUR_WCS_API_KEY" )
16+ OPENAI_APIKEY = os .getenv ("OPENAI_APIKEY" )
1317
1418
1519class ExportWeaviate (ExportVDB ):
@@ -23,6 +27,15 @@ def make_parser(cls, subparsers):
2327
2428 parser_weaviate .add_argument ("--url" , type = str , help = "URL of Weaviate instance" )
2529 parser_weaviate .add_argument ("--api_key" , type = str , help = "Weaviate API key" )
30+ parser_weaviate .add_argument ("--openai_api_key" , type = str , help = "Openai API key" )
31+ parser_weaviate .add_arguments (
32+ "--batch_size" , type = int , help = "batch size for fetching" ,
33+ default = 1000
34+ )
35+ parser_weaviate .add_argument (
36+ "--connection-type" , type = str , choices = ["local" , "cloud" ], default = "cloud" ,
37+ help = "Type of connection to Weaviate (local or cloud)"
38+ )
2639 parser_weaviate .add_argument (
2740 "--classes" , type = str , help = "Classes to export (comma-separated)"
2841 )
@@ -35,6 +48,12 @@ def export_vdb(cls, args):
3548 "Enter the URL of Weaviate instance: " ,
3649 str ,
3750 )
51+ set_arg_from_input (
52+ args ,
53+ "connection_type" ,
54+ "Enter 'local' or 'cloud' for connection types: " ,
55+ choices = ['local' , 'cloud' ],
56+ )
3857 set_arg_from_password (
3958 args ,
4059 "api_key" ,
@@ -55,14 +74,20 @@ def export_vdb(cls, args):
5574 weaviate_export .get_data ()
5675 return weaviate_export
5776
58- # Connect to a WCS instance
77+ # Connect to a WCS or local instance
5978 def __init__ (self , args ):
6079 super ().__init__ (args )
61- self .client = weaviate .connect_to_wcs (
62- cluster_url = self .args ["url" ],
63- auth_credentials = weaviate .auth .AuthApiKey (self .args ["api_key" ]),
64- skip_init_checks = True ,
65- )
80+ if self .args ["connection_type" ] == "local" :
81+ self .client = weaviate .connect_to_local ()
82+ else :
83+ self .client = weaviate .connect_to_wcs (
84+ cluster_url = self .args ["url" ],
85+ auth_credentials = weaviate .auth .AuthApiKey (self .args ["api_key" ]),
86+ headers = {'X-OpenAI-Api-key' : self .args ["openai_api_key" ]}
87+ if self .args ["openai_api_key" ]
88+ else None ,
89+ skip_init_checks = True ,
90+ )
6691
6792 def get_index_names (self ):
6893 if self .args .get ("classes" ) is None :
@@ -76,14 +101,60 @@ def get_index_names(self):
76101 return [c for c in self .all_classes if c in input_classes ]
77102
78103 def get_data (self ):
79- # Get all objects of a class
104+ # Get the index names to export
80105 index_names = self .get_index_names ()
81- for class_name in index_names :
82- collection = self .client .collections .get (class_name )
106+ index_metas : Dict [str , List [NamespaceMeta ]] = {}
107+
108+ # Iterate over index names and fetch data
109+ for index_name in index_names :
110+ collection = self .client .collections .get (index_name )
83111 response = collection .aggregate .over_all (total_count = True )
84- print (f"{ response .total_count = } " )
112+ total_vector_count = response .total_count
113+
114+ # Create vectors directory for this index
115+ vectors_directory = self .create_vec_dir (index_name )
116+
117+ # Export data in batches
118+ batch_size = self .args ["batch_size" ]
119+ num_batches = (total_vector_count + batch_size - 1 ) // batch_size
120+ num_vectors_exported = 0
121+
122+ for batch_idx in tqdm (range (num_batches ), desc = f"Exporting { index_name } " ):
123+ offset = batch_idx * batch_size
124+ objects = collection .objects .limit (batch_size ).offset (offset ).get ()
125+
126+ # Extract vectors and metadata
127+ vectors = {obj .id : obj .vector for obj in objects }
128+ metadata = {}
129+ # Need a better way
130+ for obj in objects :
131+ metadata [obj .id ] = {attr : getattr (obj , attr ) for attr in dir (obj ) if not attr .startswith ("__" )}
132+
133+
134+ # Save vectors and metadata to Parquet file
135+ num_vectors_exported += self .save_vectors_to_parquet (
136+ vectors , metadata , vectors_directory
137+ )
138+
139+ # Create NamespaceMeta for this index
140+ namespace_metas = [
141+ self .get_namespace_meta (
142+ index_name ,
143+ vectors_directory ,
144+ total = total_vector_count ,
145+ num_vectors_exported = num_vectors_exported ,
146+ dim = 300 , # Not sure of the dimensions
147+ distance = "Cosine" ,
148+ )
149+ ]
150+ index_metas [index_name ] = namespace_metas
151+
152+ # Write VDFMeta to JSON file
153+ self .file_structure .append (os .path .join (self .vdf_directory , "VDF_META.json" ))
154+ internal_metadata = self .get_basic_vdf_meta (index_metas )
155+ meta_text = json .dumps (internal_metadata .model_dump (), indent = 4 )
156+ tqdm .write (meta_text )
157+
158+ print ("Data export complete." )
85159
86- # objects = self.client.query.get(
87- # wvq.Objects(wvq.Class(class_name)).with_limit(1000)
88- # )
89- # print(objects)
160+ return True
0 commit comments