
Commit 738f9e2

updated.
1 parent 148a2de commit 738f9e2

8 files changed: +147, -135 lines

dataset.py

Lines changed: 22 additions & 16 deletions
@@ -490,6 +490,7 @@ def __init__(self, proposals, frame_list, max_zero_weight=0.25):
 
         if percent < max_zero_weight:
             # We don't care if there aren't many zeros.
+            weights.append(1)
            continue
 
        # Otherwise, we roughly want there to be 10% zeros at most.
@@ -525,8 +526,6 @@ def __init__(self, opt, subset="train"):
     def _exists(self, video_name):
         pgm_proposals_path = os.path.join(self.opt['pgm_proposals_dir'], '%s.proposals.csv' % video_name)
         pgm_features_path = os.path.join(self.opt['pgm_features_dir'], '%s.features.npy' % video_name)
-        print(pgm_proposals_path)
-        print(pgm_features_path)
         return os.path.exists(pgm_proposals_path) and os.path.exists(pgm_features_path)
 
     def _getDatasetDict(self):
@@ -536,16 +535,18 @@ def _getDatasetDict(self):
         for i in range(len(anno_df)):
             video_name = anno_df.video.values[i]
             video_info = anno_database[video_name]
-            video_subset = anno_df.subset.values[i]
+
+            if 'thumos' in self.opt['dataset']:
+                video_subset = video_name.split('_')[1].replace('validation', 'train')
+            else:
+                video_subset = anno_df.subset.values[i]
+
             if self.subset == "full":
                 self.video_dict[video_name] = video_info
             if self.subset in video_subset:
                 self.video_dict[video_name] = video_info
         self.video_list = sorted(self.video_dict.keys())
         self.video_list = [k for k in self.video_list if self._exists(k)]
-        print('\n***\n')
-        print(self.subset)
-        print(self.video_list)
 
         if self.opt['pem_do_index']:
             self.features = {}
@@ -554,15 +555,19 @@ def _getDatasetDict(self):
             for video_name in self.video_list:
                 pgm_proposals_path = os.path.join(self.opt['pgm_proposals_dir'], '%s.proposals.csv' % video_name)
                 pgm_features_path = os.path.join(self.opt['pgm_features_dir'], '%s.features.npy' % video_name)
-                pdf = pd.read_csv(pgm_proposals_path)
-                pdf = pdf.sort_values(by="score", ascending=False)
+                pdf = pd.read_csv(pgm_proposals_path)
                 video_feature = np.load(pgm_features_path)
-                video_feature = video_feature[pdf[:self.top_K].index]
                 pre_count = len(pdf)
-                pdf = pdf[:self.top_K]
-                print(video_name, pre_count, len(pdf), video_feature.shape)
-                print('Num zeros in match_iou: ', len(pdf[pdf.match_iou == 0]))
-                print('')
+                if self.top_K is not None:
+                    try:
+                        pdf = pdf.sort_values(by="score", ascending=False)
+                    except KeyError:
+                        pdf['score'] = pdf.xmin_score * pdf.xmax_score
+                        pdf = pdf.sort_values(by="score", ascending=False)
+                    pdf = pdf[:self.top_K]
+                    video_feature = video_feature[pdf.index]
+
                 print(video_name, pre_count, len(pdf), video_feature.shape, pgm_proposals_path, pgm_features_path)
                 self.proposals[video_name] = pdf
                 self.features[video_name] = video_feature
                 self.indices.extend([(video_name, i) for i in range(len(pdf))])
@@ -600,10 +605,11 @@ def __getitem__(self, index):
             # ***
             pdf = pdf.sort_values(by="score", ascending=False)
             # ***
-            pdf = pdf[:self.top_K]
-
             video_feature = np.load(pgm_features_path)
-            video_feature = video_feature[:self.top_K, :]
+            if self.top_K is not None:
+                pdf = pdf[:self.top_K]
+                video_feature = video_feature[:self.top_K, :]
+
             video_feature = torch.Tensor(video_feature)
 
             if self.mode == "train":
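
A note on the top-K change above: proposals are now sorted by a "score" column, and when that column is absent it is derived from xmin_score * xmax_score before truncating to top_K and selecting the matching feature rows. A minimal standalone sketch of that logic, with the column names taken from the diff and everything else assumed:

    import numpy as np
    import pandas as pd

    def select_top_k(pdf, video_feature, top_k):
        # Sort proposals by score (deriving it when missing), keep the top_k
        # rows, and pick the matching rows of the feature array.
        if top_k is None:
            return pdf, video_feature
        try:
            pdf = pdf.sort_values(by="score", ascending=False)
        except KeyError:
            # Fallback used in the diff when no precomputed score exists.
            pdf["score"] = pdf.xmin_score * pdf.xmax_score
            pdf = pdf.sort_values(by="score", ascending=False)
        pdf = pdf[:top_k]
        return pdf, video_feature[pdf.index]

This relies on the proposals CSV having a default RangeIndex, so pdf.index still lines up with row positions in the .features.npy array after sorting.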

gen_pgm_results_jobs.py

Lines changed: 7 additions & 1 deletion
@@ -23,16 +23,22 @@
 
 for tem_results_subdir in os.listdir(tem_results_dir):
     counter = int(regex.match(tem_results_subdir).groups()[0])
+    if counter in [301]:
+        continue
+
+    print(tem_results_dir, tem_results_subdir, counter)
     job = run(find_counter=counter)
 
     name = job['name']
     for ckpt_subdir in os.listdir(os.path.join(tem_results_dir, tem_results_subdir)):
         _job = deepcopy(job)
+        if 'thumos' in _job['dataset']:
+            _job['video_anno'] = os.path.join(_job['video_info'], 'thumos_anno_action.json')
         _job['num_gpus'] = 0
         _job['num_cpus'] = 48
         _job['pgm_thread'] = 40
         _job['gb'] = 64
-        _job['time'] = 6 # what time should this be?
+        _job['time'] = 5 # how long?
         _job['tem_results_dir'] = os.path.join(tem_results_dir, tem_results_subdir, ckpt_subdir)
 
         propdir = os.path.join(pgm_proposals_dir, tem_results_subdir, ckpt_subdir)

gen_tem_results_jobs.py

Lines changed: 8 additions & 0 deletions
@@ -28,6 +28,9 @@
 
 for ckpt_subdir in os.listdir(ckpt_directory):
     counter = int(regex.match(ckpt_subdir).groups()[0])
+    if counter not in [195]:
+        continue
+
     _job = run(find_counter=counter)
     _job['num_gpus'] = 8
     _job['num_cpus'] = 8 * 10
@@ -40,6 +43,11 @@
     _job['tem_results_subset'] = 'full'
     name = _job['name']
     for ckpt_epoch in [5, 15, 20]:
+        if counter == 195:
+            if ckpt_epoch < 20:
+                continue
+            ckpt_epoch = 19
+
         _job['checkpoint_epoch'] = ckpt_epoch
         _job['name'] = '%s.ckpt%d' % (name, ckpt_epoch)
         print(ckpt_subdir, counter)

loss_function.py

Lines changed: 3 additions & 3 deletions
@@ -90,9 +90,9 @@ def PEM_loss_function(anchors_iou, match_iou, opt):
 
     iou_weights = u_hmask + u_smmask + u_slmask
     iou_loss = F.smooth_l1_loss(anchors_iou, match_iou.squeeze())
-    print('LOSS')
-    print(iou_loss.shape, iou_weights.shape)
-    print(iou_weights)
+    # print('LOSS')
+    # print(iou_loss.shape, iou_weights.shape)
+    # print(iou_weights)
     iou_loss = torch.sum(iou_loss * iou_weights) / (1e-6 + torch.sum(iou_weights))
 
     return {'iou_loss': iou_loss}
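
For context on the prints being commented out here: the PEM loss is a weighted average of an element-wise smooth L1 term, sum(loss * w) / (1e-6 + sum(w)), over the sampled IoU entries. A minimal sketch of that weighted average, written with an explicit reduction='none' so the loss stays element-wise (which the element-wise multiply with iou_weights appears to assume):

    import torch
    import torch.nn.functional as F

    def weighted_smooth_l1(pred_iou, match_iou, weights, eps=1e-6):
        # Per-element smooth L1, averaged only over the weighted entries.
        loss = F.smooth_l1_loss(pred_iou, match_iou, reduction='none')
        return torch.sum(loss * weights) / (eps + torch.sum(weights))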

main.py

Lines changed: 7 additions & 10 deletions
@@ -89,7 +89,7 @@ def train_TEM(data_loader, model, optimizer, epoch, global_step, comet_exp, opt)
 
     print('Count: ', count)
     epoch_sums, epoch_avg = compute_metrics(epoch_sums, loss, count)
-    epoch_avg['current_l2'] = sum([W.norm(2) for W in model.module.parameters()])
+    epoch_avg['current_l2'] = sum([W.norm(2) for W in model.module.parameters()]).cpu().detach().numpy()
     steps_per_second = (n_iter+1) / (time.time() - start)
     epoch_avg['steps_per_second'] = steps_per_second
     print('\n***End of Epoch %d***\nS/S %.3f, Global Step %d, Local Step %d / %d.' % (epoch, steps_per_second, global_step, n_iter, len(data_loader)))
@@ -241,7 +241,6 @@ def test_PEM(data_loader, model, epoch, global_step, comet_exp, opt):
 def BSN_Train_TEM(opt):
     if opt['do_representation']:
         model = TEM(opt)
-        img_loading_func = get_img_loader(opt)
         partial_load(opt['representation_checkpoint'], model)
         for param in model.representation_model.parameters():
             param.requires_grad = False
@@ -258,6 +257,7 @@ def BSN_Train_TEM(opt):
                                 weight_decay=opt["tem_weight_decay"])
 
     if opt['dataset'] == 'gymnastics':
+        img_loading_func = get_img_loader(opt)
         train_data_set = GymnasticsDataSet(opt, subset=opt['tem_train_subset'], img_loading_func=img_loading_func, overlap_windows=True)
         train_sampler = GymnasticsSampler(train_data_set.video_dict, train_data_set.frame_list)
         test_data_set = GymnasticsDataSet(opt, subset="test", img_loading_func=img_loading_func)
@@ -323,10 +323,10 @@ def BSN_Train_TEM(opt):
         comet_exp.set_name(opt['name'])
 
     for epoch in range(opt["tem_epoch"]):
-        # test_TEM(test_loader, model, epoch, global_step, comet_exp, opt)
         global_step = train_TEM(train_loader, model, optimizer, epoch, global_step, comet_exp, opt)
         scheduler.step()
-        test_TEM(test_loader, model, epoch, global_step, comet_exp, opt)
+        test_TEM(test_loader, model, epoch, global_step, comet_exp, opt)
+        # test_TEM(test_loader, model, epoch, global_step, comet_exp, opt)
 
 
 def BSN_Train_PEM(opt):
@@ -344,7 +344,6 @@ def collate_fn(batch):
         batch_iou = torch.cat([x[1] for x in batch])
         return batch_data, batch_iou
 
-
     train_dataset = ProposalDataSet(opt, subset="train")
     train_sampler = ProposalSampler(train_dataset.proposals, train_dataset.indices, max_zero_weight=opt['pem_max_zero_weight'])
 
@@ -396,17 +395,17 @@ def collate_fn(batch):
         comet_exp.set_name(opt['name'])
 
     for epoch in range(opt["pem_epoch"]):
+        test_PEM(test_loader, model, epoch, global_step, comet_exp, opt)
         global_step = train_PEM(train_loader, model, optimizer, epoch, global_step, comet_exp, opt)
         scheduler.step()
-        test_PEM(test_loader, model, epoch, global_step, comet_exp, opt)
+        test_PEM(test_loader, model, epoch, global_step, comet_exp, opt)
 
 
 def BSN_inference_TEM(opt):
     output_dir = os.path.join(opt['tem_results_dir'], opt['checkpoint_path'].split('/')[-1])
     print(sorted(opt.items()))
 
     model = TEM(opt)
-    img_loading_func = get_img_loader(opt)
     checkpoint_epoch = opt['checkpoint_epoch']
     if checkpoint_epoch is not None:
         checkpoint_path = os.path.join(opt['checkpoint_path'], 'tem_checkpoint.%d.pth' % checkpoint_epoch)
@@ -429,6 +428,7 @@ def BSN_inference_TEM(opt):
     model.eval()
 
     if opt['dataset'] == 'gymnastics':
+        img_loading_func = get_img_loader(opt)
         dataset = GymnasticsDataSet(opt, subset=opt['tem_results_subset'], img_loading_func=img_loading_func)
     elif opt['dataset'] == 'thumosfeatures':
         feature_dirs = opt['feature_dirs'].split(',')
@@ -488,11 +488,8 @@ def BSN_inference_TEM(opt):
                 current_data[0].extend(batch_action[batch_idx])
                 current_data[1].extend(batch_start[batch_idx])
                 current_data[2].extend(batch_end[batch_idx])
-                # NOTE: Will this work ?
-                ###
                 current_data[3].extend(anchor_xmin)
                 current_data[4].extend(anchor_xmax)
-                ###
                 current_data[5].extend(list(frames))
             else:
                 batch_video = video[batch_idx]
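
On the train_TEM change near the top of this file: the summed parameter norm is now detached and moved to CPU before being stored in epoch_avg, so a live CUDA tensor with autograd history is not kept around just for logging. A sketch of the same idea, assuming a DataParallel-wrapped model as in the diff; the diff uses .cpu().detach().numpy(), while .item() under no_grad gives an equivalent plain Python number for a scalar:

    import torch

    def current_l2(model):
        # Sum of L2 norms over all parameters, detached so it can be
        # logged as a plain number without holding GPU memory or graph.
        with torch.no_grad():
            total = sum(p.norm(2) for p in model.module.parameters())
        return total.item()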
