
Commit 18976b8

Fixed query issue when saving Training session. More exception handling (#101)
* Fixed query issue when saving Training session. More exception handling
* Split lines
* Tabbing
1 parent 8220a68 commit 18976b8

File tree

3 files changed: 83 additions, 61 deletions


functions/pipeline/shared/db_access/db_access_v2.py

Lines changed: 1 addition & 1 deletion
@@ -427,7 +427,7 @@ def add_training_session(self, training: TrainingSession, user_id: int):
              "VALUES ('{}','{}',{},{}) RETURNING TrainingId), "
              "p AS (INSERT INTO Class_Performance (TrainingId,ClassificationId,AvgPerf) "
              "VALUES ")
-        query.format(training.description,training.model_url,training.avg_perf,user_id)
+        query = query.format(training.description,training.model_url,training.avg_perf,user_id)

         # Append multiple TrainingId, ClassificationId and Performance values to above query
         # Comma is more rows, closing parenthesis is on the last row
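The one-line change above matters because Python strings are immutable: str.format returns a new string rather than modifying the receiver, so the pre-commit code built the formatted SQL and immediately discarded it, leaving the raw template with literal {} placeholders as the query. A minimal illustration of the difference, with made-up values standing in for the real training fields:

    template = "VALUES ('{}','{}',{},{}) RETURNING TrainingId"
    template.format("desc", "http://model/url", 0.9, 1)           # result discarded; template is unchanged
    query = template.format("desc", "http://model/url", 0.9, 1)   # the commit's fix: rebind the formatted result

Since the values are still spliced in with str.format, passing them as execute() parameters to the database driver instead would also avoid quoting and injection pitfalls, but that is beyond the scope of this fix.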

functions/pipeline/train/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ def main(req: func.HttpRequest) -> func.HttpResponse:
         )
     elif req.method == "POST":
         payload = json.loads(req.get_body())
+        logging.debug("Payload: {}".format(payload))
         payload_json = namedtuple('TrainingSession', payload.keys())(*payload.values())
         training_id = data_access.add_training_session(payload_json, user_id)
         return func.HttpResponse(
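A side note on the added debug line: "Payload: {}".format(payload) builds the string eagerly even when DEBUG logging is filtered out, while the logging module's own %-style interpolation defers formatting until the record is actually emitted. A small sketch, using a hypothetical payload for illustration:

    import logging

    payload = {"description": "nightly run", "avg_perf": 0.87}   # hypothetical payload for illustration

    logging.debug("Payload: {}".format(payload))   # eager: string is built even if DEBUG is disabled
    logging.debug("Payload: %s", payload)          # lazy: formatted only if the record is emitted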

train_vnext/training.py

Lines changed: 81 additions & 60 deletions
@@ -183,66 +183,83 @@ def save_training_session(config, model_location, perf_location, prediction_labe
     upload_data_post_training(prediction_labels_location, classification_name_to_class_id, training_id, config.get("tagging_user"), config.get("url"))

 def upload_model_to_blob_storage(config, model_location, file_name, user_name):
-    blob_storage = BlobStorage.get_azure_storage_client(config)
-    blob_metadata = {
-        "userFilePath": model_location,
-        "uploadUser": user_name
-    }
-    uri = 'https://' + config.get("storage_account") + '.blob.core.windows.net/' + config.get("storage_container") + '/' + file_name
-    blob_storage.create_blob_from_path(
-        config.get("storage_container"),
-        file_name,
-        model_location,
-        metadata=blob_metadata
-    )
-    print("Model uploaded at " + str(uri))
-    return uri
+    try:
+        blob_storage = BlobStorage.get_azure_storage_client(config)
+        blob_metadata = {
+            "userFilePath": model_location,
+            "uploadUser": user_name
+        }
+        uri = 'https://' + config.get("storage_account") + '.blob.core.windows.net/' + config.get("storage_container") + '/' + file_name
+        blob_storage.create_blob_from_path(
+            config.get("storage_container"),
+            file_name,
+            model_location,
+            metadata=blob_metadata
+        )
+        print("Model uploaded at " + str(uri))
+        return uri
+    except Exception as e:
+        print("Issue uploading model to cloud storage: {}",e)

 def construct_new_training_session(perf_location, classification_name_to_class_id, overall_average, training_description, model_location, avg_dictionary, user_name, function_url):
-    training_session = TrainingSession(training_description, model_location, overall_average, avg_dictionary)
-    query = {
-        "userName": user_name
-    }
-    function_url = function_url + "/api/train"
-    payload = jsonpickle.encode(training_session, unpicklable=False)
-    response = requests.post(function_url, params=query, json=payload)
-    training_id = int(response.json())
-    print("Created a new training session with id: " + str(training_id))
-    return training_id
+    try:
+        training_session = TrainingSession(training_description, model_location, overall_average, avg_dictionary)
+        query = {
+            "userName": user_name
+        }
+        function_url = function_url + "/api/train"
+        payload = jsonpickle.encode(training_session, unpicklable=False)
+        response = requests.post(function_url, params=query, json=payload)
+        response.raise_for_status()
+        training_id = int(response.json())
+        print("Created a new training session with id: " + str(training_id))
+        return training_id
+    except requests.exceptions.HTTPError as e:
+        print("HTTP Error when saving training session: {}",e.response.content)
+        raise
+    except Exception as e:
+        print("Issue saving training session: {}",e)

 def process_classifications(perf_location, user_name,function_url):
-    # First build query string to get classification map
-    classes = ""
-    query = {
-        "userName": user_name
-    }
-    function_url = function_url + "/api/classification"
-    overall_average = 0.0
-    with open(perf_location) as f:
-        content = csv.reader(f, delimiter=',')
-        next(content, None) #Skip header
-        for line in content:
-            class_name = line[0].strip()
-            if class_name == "Average":
-                overall_average = line[1]
-            elif class_name not in classes and class_name != "NULL":
-                classes = classes + class_name + ","
-
-    query["className"] = classes[:-1]
-    print("Getting classification map for classes " + query["className"])
-    response = requests.get(function_url, params=query)
-    classification_name_to_class_id = response.json()
-
-    # Now that we have classification map, build the dictionary that maps class id : average
-    avg_dictionary = {}
-    with open(perf_location) as csvfile:
-        reader = csv.reader(csvfile, delimiter=',')
-        next(reader, None) #Skip header
-        for row in reader:
-            if row[0] != "NULL" and row[0] in classification_name_to_class_id:
-                avg_dictionary[classification_name_to_class_id[row[0]]] = row[1]
-
-    return overall_average, classification_name_to_class_id, avg_dictionary
+    try:
+        # First build query string to get classification map
+        classes = ""
+        query = {
+            "userName": user_name
+        }
+        function_url = function_url + "/api/classification"
+        overall_average = 0.0
+        with open(perf_location) as f:
+            content = csv.reader(f, delimiter=',')
+            next(content, None) #Skip header
+            for line in content:
+                class_name = line[0].strip()
+                if class_name == "Average":
+                    overall_average = line[1]
+                elif class_name not in classes and class_name != "NULL":
+                    classes = classes + class_name + ","
+
+        query["className"] = classes[:-1]
+        print("Getting classification map for classes " + query["className"])
+        response = requests.get(function_url, params=query)
+        response.raise_for_status()
+        classification_name_to_class_id = response.json()
+
+        # Now that we have classification map, build the dictionary that maps class id : average
+        avg_dictionary = {}
+        with open(perf_location) as csvfile:
+            reader = csv.reader(csvfile, delimiter=',')
+            next(reader, None) #Skip header
+            for row in reader:
+                if row[0] != "NULL" and row[0] in classification_name_to_class_id:
+                    avg_dictionary[classification_name_to_class_id[row[0]]] = row[1]
+
+        return overall_average, classification_name_to_class_id, avg_dictionary
+    except requests.exceptions.HTTPError as e:
+        print("HTTP Error when getting classification map: {}",e.response.content)
+        raise
+    except Exception as e:
+        print("Issue processing classfication: {}",e)

 def get_image_name_from_url(image_url):
     start_idx = image_url.rfind('/')+1
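Two observations on the new error handling in this hunk. First, response.raise_for_status() raises requests.exceptions.HTTPError for 4xx/5xx responses, which is what makes the dedicated HTTPError handlers reachable. Second, print("...: {}", e) passes the format string and the exception as two separate arguments, so print joins them with a space instead of substituting the {}; the broad except blocks also let the functions fall through and implicitly return None, which callers need to tolerate. A small sketch of the formatting difference, where the RuntimeError is only a stand-in for a real failure:

    try:
        raise RuntimeError("blob upload failed")                         # stand-in for the real failure
    except Exception as e:
        print("Issue uploading model to cloud storage: {}", e)           # prints the literal "{}" then the message
        print("Issue uploading model to cloud storage: {}".format(e))    # substitutes the message into the string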
@@ -269,8 +286,12 @@ def create_pascal_label_map(label_map_path: str, class_names: list):
     if operation == "start":
         train(legacy_config, config.get("tagging_user"), config.get("url"))
     elif operation == "save":
-        # Upload the model saved at ${inference_output_dir}/frozen_inference_graph.pb
+        # The model is saved relative to the python_file_directory in
+        # ${inference_output_dir}/frozen_inference_graph.pb
+        path_to_model = os.path.join(legacy_config.get("python_file_directory"),
+                                     legacy_config.get("inference_output_dir"),
+                                     "/frozen_inference_graph.pb")
         save_training_session(config,
-                              legacy_config.get("inference_output_dir") + "/frozen_inference_graph.pb",
-                              legacy_config.get("validation_output"),
-                              legacy_config.get("untagged_output"))
+                              path_to_model,
+                              legacy_config.get("validation_output"),
+                              legacy_config.get("untagged_output"))