@@ -89,7 +89,7 @@ def train_TEM(data_loader, model, optimizer, epoch, global_step, comet_exp, opt)
89
89
90
90
print ('Count: ' , count )
91
91
epoch_sums , epoch_avg = compute_metrics (epoch_sums , loss , count )
92
- epoch_avg ['current_l2' ] = sum ([W .norm (2 ) for W in model .module .parameters ()])
92
+ epoch_avg ['current_l2' ] = sum ([W .norm (2 ) for W in model .module .parameters ()]). cpu (). detach (). numpy ()
93
93
steps_per_second = (n_iter + 1 ) / (time .time () - start )
94
94
epoch_avg ['steps_per_second' ] = steps_per_second
95
95
print ('\n ***End of Epoch %d***\n S/S %.3f, Global Step %d, Local Step %d / %d.' % (epoch , steps_per_second , global_step , n_iter , len (data_loader )))
@@ -241,7 +241,6 @@ def test_PEM(data_loader, model, epoch, global_step, comet_exp, opt):
241
241
def BSN_Train_TEM (opt ):
242
242
if opt ['do_representation' ]:
243
243
model = TEM (opt )
244
- img_loading_func = get_img_loader (opt )
245
244
partial_load (opt ['representation_checkpoint' ], model )
246
245
for param in model .representation_model .parameters ():
247
246
param .requires_grad = False
@@ -258,6 +257,7 @@ def BSN_Train_TEM(opt):
258
257
weight_decay = opt ["tem_weight_decay" ])
259
258
260
259
if opt ['dataset' ] == 'gymnastics' :
260
+ img_loading_func = get_img_loader (opt )
261
261
train_data_set = GymnasticsDataSet (opt , subset = opt ['tem_train_subset' ], img_loading_func = img_loading_func , overlap_windows = True )
262
262
train_sampler = GymnasticsSampler (train_data_set .video_dict , train_data_set .frame_list )
263
263
test_data_set = GymnasticsDataSet (opt , subset = "test" , img_loading_func = img_loading_func )
@@ -323,10 +323,10 @@ def BSN_Train_TEM(opt):
323
323
comet_exp .set_name (opt ['name' ])
324
324
325
325
for epoch in range (opt ["tem_epoch" ]):
326
- # test_TEM(test_loader, model, epoch, global_step, comet_exp, opt)
327
326
global_step = train_TEM (train_loader , model , optimizer , epoch , global_step , comet_exp , opt )
328
327
scheduler .step ()
329
- test_TEM (test_loader , model , epoch , global_step , comet_exp , opt )
328
+ test_TEM (test_loader , model , epoch , global_step , comet_exp , opt )
329
+ # test_TEM(test_loader, model, epoch, global_step, comet_exp, opt)
330
330
331
331
332
332
def BSN_Train_PEM (opt ):
@@ -344,7 +344,6 @@ def collate_fn(batch):
344
344
batch_iou = torch .cat ([x [1 ] for x in batch ])
345
345
return batch_data , batch_iou
346
346
347
-
348
347
train_dataset = ProposalDataSet (opt , subset = "train" )
349
348
train_sampler = ProposalSampler (train_dataset .proposals , train_dataset .indices , max_zero_weight = opt ['pem_max_zero_weight' ])
350
349
@@ -396,17 +395,17 @@ def collate_fn(batch):
396
395
comet_exp .set_name (opt ['name' ])
397
396
398
397
for epoch in range (opt ["pem_epoch" ]):
398
+ test_PEM (test_loader , model , epoch , global_step , comet_exp , opt )
399
399
global_step = train_PEM (train_loader , model , optimizer , epoch , global_step , comet_exp , opt )
400
400
scheduler .step ()
401
- test_PEM (test_loader , model , epoch , global_step , comet_exp , opt )
401
+ test_PEM (test_loader , model , epoch , global_step , comet_exp , opt )
402
402
403
403
404
404
def BSN_inference_TEM (opt ):
405
405
output_dir = os .path .join (opt ['tem_results_dir' ], opt ['checkpoint_path' ].split ('/' )[- 1 ])
406
406
print (sorted (opt .items ()))
407
407
408
408
model = TEM (opt )
409
- img_loading_func = get_img_loader (opt )
410
409
checkpoint_epoch = opt ['checkpoint_epoch' ]
411
410
if checkpoint_epoch is not None :
412
411
checkpoint_path = os .path .join (opt ['checkpoint_path' ], 'tem_checkpoint.%d.pth' % checkpoint_epoch )
@@ -429,6 +428,7 @@ def BSN_inference_TEM(opt):
429
428
model .eval ()
430
429
431
430
if opt ['dataset' ] == 'gymnastics' :
431
+ img_loading_func = get_img_loader (opt )
432
432
dataset = GymnasticsDataSet (opt , subset = opt ['tem_results_subset' ], img_loading_func = img_loading_func )
433
433
elif opt ['dataset' ] == 'thumosfeatures' :
434
434
feature_dirs = opt ['feature_dirs' ].split (',' )
@@ -488,11 +488,8 @@ def BSN_inference_TEM(opt):
488
488
current_data [0 ].extend (batch_action [batch_idx ])
489
489
current_data [1 ].extend (batch_start [batch_idx ])
490
490
current_data [2 ].extend (batch_end [batch_idx ])
491
- # NOTE: Will this work ?
492
- ###
493
491
current_data [3 ].extend (anchor_xmin )
494
492
current_data [4 ].extend (anchor_xmax )
495
- ###
496
493
current_data [5 ].extend (list (frames ))
497
494
else :
498
495
batch_video = video [batch_idx ]
0 commit comments