11import os
22import weaviate
3- import json
43from tqdm import tqdm
54from vdf_io .import_vdf .vdf_import_cls import ImportVDB
65from vdf_io .names import DBNames
76from vdf_io .util import set_arg_from_input , set_arg_from_password
7+ from vdf_io .constants import INT_MAX , DEFAULT_BATCH_SIZE
88
99# Set these environment variables
1010URL = os .getenv ("YOUR_WCS_URL" )
@@ -25,6 +25,14 @@ def make_parser(cls, subparsers):
2525 parser_weaviate .add_argument (
2626 "--index_name" , type = str , help = "Name of the index in Weaviate"
2727 )
28+ parser_weaviate .add_argument (
29+ "--connection-type" , type = str , choices = ["local" , "cloud" ], default = "cloud" ,
30+ help = "Type of connection to Weaviate (local or cloud)"
31+ )
32+ parser_weaviate .add_argument (
33+ "--batch_size" , type = int , help = "batch size for fetching" ,
34+ default = DEFAULT_BATCH_SIZE
35+ )
2836
2937 @classmethod
3038 def import_vdb (cls , args ):
@@ -34,18 +42,24 @@ def import_vdb(cls, args):
3442 "Enter the URL of Weaviate instance: " ,
3543 str ,
3644 )
37- set_arg_from_password (
38- args ,
39- "api_key" ,
40- "Enter the Weaviate API key: " ,
41- "WEAVIATE_API_KEY" ,
42- )
4345 set_arg_from_input (
4446 args ,
4547 "index_name" ,
4648 "Enter the name of the index in Weaviate: " ,
4749 str ,
4850 )
51+ set_arg_from_input (
52+ args ,
53+ "connection_type" ,
54+ "Enter 'local' or 'cloud' for connection types: " ,
55+ choices = ['local' , 'cloud' ],
56+ )
57+ set_arg_from_password (
58+ args ,
59+ "api_key" ,
60+ "Enter the Weaviate API key: " ,
61+ "WEAVIATE_API_KEY" ,
62+ )
4963 weaviate_import = ImportWeaviate (args )
5064 weaviate_import .upsert_data ()
5165 return weaviate_import
@@ -76,7 +90,6 @@ def upsert_data(self):
7690
7791 # Create or get the index
7892 index_name = self .create_new_name (index_name , self .client .collections .list_all ().keys ())
79- index = self .client .collections .get (index_name )
8093
8194 # Load data from the Parquet files
8295 data_path = namespace_meta ["data_path" ]
@@ -85,55 +98,43 @@ def upsert_data(self):
8598
8699 vectors = {}
87100 metadata = {}
101+ vector_column_names , vector_column_name = self .get_vector_column_name (
102+ index_name , namespace_meta
103+ )
88104
89- # for file in tqdm(parquet_files, desc="Loading data from parquet files"):
90- # file_path = os.path.join(final_data_path, file)
91- # df = self.read_parquet_progress(file_path)
92-
93- # if len(vectors) > (self.args.get("max_num_rows") or INT_MAX):
94- # max_hit = True
95- # break
96-
97- # self.update_vectors(vectors, vector_column_name, df)
98- # self.update_metadata(metadata, vector_column_names, df)
99- # if max_hit:
100- # break
101-
102- # tqdm.write(f"Loaded {len(vectors)} vectors from {len(parquet_files)} parquet files")
103-
104- # # Upsert the vectors and metadata to the Weaviate index in batches
105- # BATCH_SIZE = self.args.get("batch_size", 1000) or 1000
106- # current_batch_size = BATCH_SIZE
107- # start_idx = 0
108-
109- # while start_idx < len(vectors):
110- # end_idx = min(start_idx + current_batch_size, len(vectors))
111-
112- # batch_vectors = [
113- # (
114- # str(id),
115- # vector,
116- # {
117- # k: v
118- # for k, v in metadata.get(id, {}).items()
119- # if v is not None
120- # } if len(metadata.get(id, {}).keys()) > 0 else None
121- # )
122- # for id, vector in list(vectors.items())[start_idx:end_idx]
123- # ]
124-
125- # try:
126- # resp = index.batch.create(batch_vectors)
127- # total_imported_count += len(batch_vectors)
128- # start_idx += len(batch_vectors)
129- # except Exception as e:
130- # tqdm.write(f"Error upserting vectors for index '{index_name}', {e}")
131- # if current_batch_size < BATCH_SIZE / 100:
132- # tqdm.write("Batch size is not the issue. Aborting import")
133- # raise e
134- # current_batch_size = int(2 * current_batch_size / 3)
135- # tqdm.write(f"Reducing batch size to {current_batch_size}")
136- # continue
137-
138- # tqdm.write(f"Data import completed successfully. Imported {total_imported_count} vectors")
139- # self.args["imported_count"] = total_imported_count
105+ for file in tqdm (parquet_files , desc = "Loading data from parquet files" ):
106+ file_path = os .path .join (final_data_path , file )
107+ df = self .read_parquet_progress (file_path )
108+
109+ if len (vectors ) > (self .args .get ("max_num_rows" ) or INT_MAX ):
110+ max_hit = True
111+ break
112+ if len (vectors ) + len (df ) > (
113+ self .args .get ("max_num_rows" ) or INT_MAX
114+ ):
115+ df = df .head (
116+ (self .args .get ("max_num_rows" ) or INT_MAX ) - len (vectors )
117+ )
118+ max_hit = True
119+ self .update_vectors (vectors , vector_column_name , df )
120+ self .update_metadata (metadata , vector_column_names , df )
121+ if max_hit :
122+ break
123+
124+ tqdm .write (f"Loaded { len (vectors )} vectors from { len (parquet_files )} parquet files" )
125+
126+ # Upsert the vectors and metadata to the Weaviate index in batches
127+ BATCH_SIZE = self .args .get ("batch_size" )
128+
129+ with self .client .batch .fixed_size (batch_size = BATCH_SIZE ) as batch :
130+ for _ , vector in vectors .items ():
131+ batch .add_object (
132+ vector = vector ,
133+ collection = index_name
134+ #TODO: Find way to add Metadata
135+ )
136+ total_imported_count += 1
137+
138+
139+ tqdm .write (f"Data import completed successfully. Imported { total_imported_count } vectors" )
140+ self .args ["imported_count" ] = total_imported_count
0 commit comments