@@ -183,66 +183,83 @@ def save_training_session(config, model_location, perf_location, prediction_labe
183
183
upload_data_post_training (prediction_labels_location , classification_name_to_class_id , training_id , config .get ("tagging_user" ), config .get ("url" ))
184
184
185
185
def upload_model_to_blob_storage (config , model_location , file_name , user_name ):
186
- blob_storage = BlobStorage .get_azure_storage_client (config )
187
- blob_metadata = {
188
- "userFilePath" : model_location ,
189
- "uploadUser" : user_name
190
- }
191
- uri = 'https://' + config .get ("storage_account" ) + '.blob.core.windows.net/' + config .get ("storage_container" ) + '/' + file_name
192
- blob_storage .create_blob_from_path (
193
- config .get ("storage_container" ),
194
- file_name ,
195
- model_location ,
196
- metadata = blob_metadata
197
- )
198
- print ("Model uploaded at " + str (uri ))
199
- return uri
186
+ try :
187
+ blob_storage = BlobStorage .get_azure_storage_client (config )
188
+ blob_metadata = {
189
+ "userFilePath" : model_location ,
190
+ "uploadUser" : user_name
191
+ }
192
+ uri = 'https://' + config .get ("storage_account" ) + '.blob.core.windows.net/' + config .get ("storage_container" ) + '/' + file_name
193
+ blob_storage .create_blob_from_path (
194
+ config .get ("storage_container" ),
195
+ file_name ,
196
+ model_location ,
197
+ metadata = blob_metadata
198
+ )
199
+ print ("Model uploaded at " + str (uri ))
200
+ return uri
201
+ except Exception as e :
202
+ print ("Issue uploading model to cloud storage: {}" ,e )
200
203
201
204
def construct_new_training_session (perf_location , classification_name_to_class_id , overall_average , training_description , model_location , avg_dictionary , user_name , function_url ):
202
- training_session = TrainingSession (training_description , model_location , overall_average , avg_dictionary )
203
- query = {
204
- "userName" : user_name
205
- }
206
- function_url = function_url + "/api/train"
207
- payload = jsonpickle .encode (training_session , unpicklable = False )
208
- response = requests .post (function_url , params = query , json = payload )
209
- training_id = int (response .json ())
210
- print ("Created a new training session with id: " + str (training_id ))
211
- return training_id
205
+ try :
206
+ training_session = TrainingSession (training_description , model_location , overall_average , avg_dictionary )
207
+ query = {
208
+ "userName" : user_name
209
+ }
210
+ function_url = function_url + "/api/train"
211
+ payload = jsonpickle .encode (training_session , unpicklable = False )
212
+ response = requests .post (function_url , params = query , json = payload )
213
+ response .raise_for_status ()
214
+ training_id = int (response .json ())
215
+ print ("Created a new training session with id: " + str (training_id ))
216
+ return training_id
217
+ except requests .exceptions .HTTPError as e :
218
+ print ("HTTP Error when saving training session: {}" ,e .response .content )
219
+ raise
220
+ except Exception as e :
221
+ print ("Issue saving training session: {}" ,e )
212
222
213
223
def process_classifications (perf_location , user_name ,function_url ):
214
- # First build query string to get classification map
215
- classes = ""
216
- query = {
217
- "userName" : user_name
218
- }
219
- function_url = function_url + "/api/classification"
220
- overall_average = 0.0
221
- with open (perf_location ) as f :
222
- content = csv .reader (f , delimiter = ',' )
223
- next (content , None ) #Skip header
224
- for line in content :
225
- class_name = line [0 ].strip ()
226
- if class_name == "Average" :
227
- overall_average = line [1 ]
228
- elif class_name not in classes and class_name != "NULL" :
229
- classes = classes + class_name + ","
230
-
231
- query ["className" ] = classes [:- 1 ]
232
- print ("Getting classification map for classes " + query ["className" ])
233
- response = requests .get (function_url , params = query )
234
- classification_name_to_class_id = response .json ()
235
-
236
- # Now that we have classification map, build the dictionary that maps class id : average
237
- avg_dictionary = {}
238
- with open (perf_location ) as csvfile :
239
- reader = csv .reader (csvfile , delimiter = ',' )
240
- next (reader , None ) #Skip header
241
- for row in reader :
242
- if row [0 ] != "NULL" and row [0 ] in classification_name_to_class_id :
243
- avg_dictionary [classification_name_to_class_id [row [0 ]]] = row [1 ]
244
-
245
- return overall_average , classification_name_to_class_id , avg_dictionary
224
+ try :
225
+ # First build query string to get classification map
226
+ classes = ""
227
+ query = {
228
+ "userName" : user_name
229
+ }
230
+ function_url = function_url + "/api/classification"
231
+ overall_average = 0.0
232
+ with open (perf_location ) as f :
233
+ content = csv .reader (f , delimiter = ',' )
234
+ next (content , None ) #Skip header
235
+ for line in content :
236
+ class_name = line [0 ].strip ()
237
+ if class_name == "Average" :
238
+ overall_average = line [1 ]
239
+ elif class_name not in classes and class_name != "NULL" :
240
+ classes = classes + class_name + ","
241
+
242
+ query ["className" ] = classes [:- 1 ]
243
+ print ("Getting classification map for classes " + query ["className" ])
244
+ response = requests .get (function_url , params = query )
245
+ response .raise_for_status ()
246
+ classification_name_to_class_id = response .json ()
247
+
248
+ # Now that we have classification map, build the dictionary that maps class id : average
249
+ avg_dictionary = {}
250
+ with open (perf_location ) as csvfile :
251
+ reader = csv .reader (csvfile , delimiter = ',' )
252
+ next (reader , None ) #Skip header
253
+ for row in reader :
254
+ if row [0 ] != "NULL" and row [0 ] in classification_name_to_class_id :
255
+ avg_dictionary [classification_name_to_class_id [row [0 ]]] = row [1 ]
256
+
257
+ return overall_average , classification_name_to_class_id , avg_dictionary
258
+ except requests .exceptions .HTTPError as e :
259
+ print ("HTTP Error when getting classification map: {}" ,e .response .content )
260
+ raise
261
+ except Exception as e :
262
+ print ("Issue processing classfication: {}" ,e )
246
263
247
264
def get_image_name_from_url (image_url ):
248
265
start_idx = image_url .rfind ('/' )+ 1
@@ -269,8 +286,12 @@ def create_pascal_label_map(label_map_path: str, class_names: list):
269
286
if operation == "start" :
270
287
train (legacy_config , config .get ("tagging_user" ), config .get ("url" ))
271
288
elif operation == "save" :
272
- # Upload the model saved at ${inference_output_dir}/frozen_inference_graph.pb
289
+ # The model is saved relative to the python_file_directory in
290
+ # ${inference_output_dir}/frozen_inference_graph.pb
291
+ path_to_model = os .path .join (legacy_config .get ("python_file_directory" ),
292
+ legacy_config .get ("inference_output_dir" ),
293
+ "/frozen_inference_graph.pb" )
273
294
save_training_session (config ,
274
- legacy_config . get ( "inference_output_dir" ) + "/frozen_inference_graph.pb" ,
275
- legacy_config .get ("validation_output" ),
276
- legacy_config .get ("untagged_output" ))
295
+ path_to_model ,
296
+ legacy_config .get ("validation_output" ),
297
+ legacy_config .get ("untagged_output" ))
0 commit comments