import json
import os
from pathlib import Path
+ import random
import re
import pickle
from urllib.parse import unquote

import pandas as pd
import torch.utils.data as data
import torch
+ from torchvision.datasets.video_utils import VideoClips


def load_json(file):
@@ -24,6 +26,7 @@ def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max):
    int_xmin = np.maximum(anchors_min, box_min)
    int_xmax = np.minimum(anchors_max, box_max)
    inter_len = np.maximum(int_xmax - int_xmin, 0.)
+     # print(anchors_min, anchors_max, box_min, box_max, int_xmin, int_xmax, inter_len)
    scores = np.divide(inter_len, len_anchors)
    return scores

@@ -47,7 +50,7 @@ def __init__(self, opt, subset=None, feature_dirs=[], fps=30, image_dir=None, im
        self.fps = fps

        # e.g. /data/thumos14_annotations/Test_Annotation.csv
-         self.video_info_path = os.path.join(opt["video_info"], '%s_Annotation.csv' % self.subset)
+         self.video_info_path = opt["video_info"]
        self._get_data()

    def _get_data(self):
@@ -325,6 +328,198 @@ def _get_image_dir(self, video_name):
        return os.path.join(self.image_dir, target_dir)


+ class VideoDataset(data.Dataset):
+     def __init__(self, opt, transforms, subset, fraction=1.):
+         """Each line of the video file list pairs a video path with its
+         annotation key: `/path/to/video.mp4 <key>.mp4`."""
+         self.subset = subset
+         self.video_info_path = opt["video_info"]
+         self.mode = opt["mode"]
+         self.boundary_ratio = opt['boundary_ratio']
+         self.skip_videoframes = opt['skip_videoframes']
+         self.num_videoframes = opt['num_videoframes']
+         self.dist_videoframes = opt['dist_videoframes']
+         self.fraction = fraction
+ 
+         subset_translate = {'train': 'training', 'val': 'validation'}
+         self.anno_df = pd.read_csv(self.video_info_path)
+         print(self.anno_df)
+         print(subset, subset_translate.get(subset))
+         self.anno_df = self.anno_df[self.anno_df.subset == subset_translate[subset]]
+         print(self.anno_df)
+ 
+         file_loc = opt['%s_video_file_list' % subset]
+         with open(file_loc, 'r') as f:
+             lines = [k.strip() for k in f.readlines()]
+ 
+         file_list = [k.split(' ')[0] for k in lines]
+         keys_list = [k.split(' ')[1][:-4] for k in lines]
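+         # The second field's 4-character extension (e.g. '.mp4') is stripped
+         # to obtain the key used in the annotation DataFrame's `video` column.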
+         print(keys_list[:5])
+         valid_keys = set(self.anno_df.video.unique())
+         valid_key_indices = [num for num, k in enumerate(keys_list)
+                              if k in valid_keys]
+         self.keys_list = [keys_list[num] for num in valid_key_indices]
+         self.file_list = [file_list[num] for num in valid_key_indices]
+         print('Number of indices: ', len(valid_key_indices), subset)
+ 
+         video_info_dir = '/'.join(self.video_info_path.split('/')[:-1])
+         clip_length_in_frames = self.num_videoframes * self.skip_videoframes
+         frames_between_clips = self.dist_videoframes
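+         # Each clip spans num_videoframes * skip_videoframes raw frames, so
+         # taking every skip_videoframes-th frame later yields num_videoframes
+         # frames. Indexing with VideoClips is slow, so the index is cached.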
+         saved_video_clips = os.path.join(
+             video_info_dir, 'video_clips.%s.%df.%ds.pkl' % (
+                 subset, clip_length_in_frames, frames_between_clips))
+         if os.path.exists(saved_video_clips):
+             print('Path Exists for video_clips: ', saved_video_clips)
+             self.video_clips = pickle.load(open(saved_video_clips, 'rb'))
+         else:
+             print('Path does NOT exist for video_clips: ', saved_video_clips)
+             self.video_clips = VideoClips(
+                 self.file_list, clip_length_in_frames=clip_length_in_frames,
+                 frames_between_clips=frames_between_clips, frame_rate=opt['fps'])
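+             # frame_rate=opt['fps'] makes VideoClips resample every video to a
+             # common fps before clips are enumerated.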
+             pickle.dump(self.video_clips, open(saved_video_clips, 'wb'))
+         print('Length of vid clips: ', self.video_clips.num_clips(), self.subset)
+ 
+         if self.mode == "train":
+             self.datums = self._retrieve_valid_datums()
+             self.datum_indices = list(range(len(self.datums)))
+             if fraction < 1:
+                 print('DOING the subset dataset on %s ...' % subset)
+                 self._subset_dataset(fraction)
+             print('Len of %s datums: ' % subset, len(self.datum_indices))
+ 
+         self.transforms = transforms
+ 
+     def _subset_dataset(self, fraction):
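+         # Keep a random `fraction` of the training datums.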
+         num_datums = int(len(self.datums) * fraction)
+         self.datum_indices = list(range(len(self.datums)))
+         random.shuffle(self.datum_indices)
+         self.datum_indices = self.datum_indices[:num_datums]
+ 
+     def __len__(self):
+         if self.mode == "train":
+             return len(self.datum_indices)
+         # Outside of training every clip is its own datum.
+         return self.video_clips.num_clips()
+ 
+     def _retrieve_valid_datums(self):
+         video_info_dir = '/'.join(self.video_info_path.split('/')[:-1])
+         num_clips = self.video_clips.num_clips()
+         saved_data_path = os.path.join(
+             video_info_dir, 'saved.%s.nf%d.sf%d.df%d.vid%d.pkl' % (
+                 self.subset, self.num_videoframes, self.skip_videoframes,
+                 self.dist_videoframes, num_clips))
+         print(saved_data_path)
+         if os.path.exists(saved_data_path):
+             print('Got saved data.')
+             with open(saved_data_path, 'rb') as f:
+                 return pickle.load(f)
+ 
+         ret = []
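+         # A clip is kept only if at least one ground-truth segment overlaps
+         # it; each datum is (flat clip index, anchor mins, anchor maxs,
+         # overlapping gt boxes).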
+         for flat_index in range(num_clips):
+             video_idx, clip_idx = self.video_clips.get_clip_location(flat_index)
+             start_frame = clip_idx * self.dist_videoframes
+             snippets = [start_frame + self.skip_videoframes * i
+                         for i in range(self.num_videoframes)]
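+             # snippets holds the raw-frame indices sampled for this clip
+             # (assuming clips start every dist_videoframes frames).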
+             key = self.keys_list[video_idx]
+             training_anchors = self._get_training_anchors(snippets, key)
+             if not training_anchors:
+                 continue
+ 
+             anchor_xmins, anchor_xmaxs, gt_bbox = training_anchors
+             ret.append((flat_index, anchor_xmins, anchor_xmaxs, gt_bbox))
+ 
+         print('Size of data: ', len(ret), flush=True)
+         with open(saved_data_path, 'wb') as f:
+             pickle.dump(ret, f)
+         print('Dumped data...')
+         return ret
+ 
+     def __getitem__(self, index):
+         # The video_data retrieved has shape [nf * sf, w, h, c].
+         # We want to pick every sf'th frame out of that.
+         if self.mode == "train":
+             datum_index = self.datum_indices[index]
+             flat_index, anchor_xmin, anchor_xmax, gt_bbox = self.datums[datum_index]
+         else:
+             # Outside of training the dataset is indexed directly by clip.
+             flat_index = index
+         video, _, _, video_idx = self.video_clips.get_clip(flat_index)
+ 
+         video_data = video[0::self.skip_videoframes]
+         video_data = self.transforms(video_data)
+         video_data = torch.transpose(video_data, 0, 1)
+ 
+         _, clip_idx = self.video_clips.get_clip_location(flat_index)
+         start_frame = clip_idx * self.dist_videoframes
+         snippets = [start_frame + self.skip_videoframes * i
+                     for i in range(self.num_videoframes)]
+         if self.mode == "train":
+             match_score_action, match_score_start, match_score_end = \
+                 self._get_train_label(gt_bbox, anchor_xmin, anchor_xmax)
+             return video_data, match_score_action, match_score_start, match_score_end
+         else:
+             video_name = self.keys_list[video_idx]
+             return flat_index, video_data, video_name, snippets
+ 
+     def _get_training_anchors(self, snippets, key):
+         tmp_anchor_xmins = np.array(snippets) - self.skip_videoframes / 2.
+         tmp_anchor_xmaxs = np.array(snippets) + self.skip_videoframes / 2.
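+         # Each sampled frame becomes an anchor interval of width
+         # skip_videoframes centered on that frame.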
459
+ tmp_gt_bbox = []
460
+ tmp_ioa_list = []
461
+ anno_df_video = self .anno_df [self .anno_df .video == key ]
462
+ gt_xmins = anno_df_video .startFrame .values [:]
463
+ gt_xmaxs = anno_df_video .endFrame .values [:]
+         if len(gt_xmins) == 0:
+             raise ValueError('No ground-truth segments for video: %s' % key)
+ 
+         for idx in range(len(gt_xmins)):
+             tmp_ioa = ioa_with_anchors(gt_xmins[idx], gt_xmaxs[idx],
+                                        tmp_anchor_xmins[0],
+                                        tmp_anchor_xmaxs[-1])
+             tmp_ioa_list.append(tmp_ioa)
+             if tmp_ioa > 0:
+                 tmp_gt_bbox.append([gt_xmins[idx], gt_xmaxs[idx]])
+ 
+         # print(len(tmp_gt_bbox), max(tmp_ioa_list), tmp_ioa_list)
+         if len(tmp_gt_bbox) > 0:
+             # NOTE: The previous 0.9 IoA threshold was removed; any positive
+             # overlap with the clip now qualifies.
+             return tmp_anchor_xmins, tmp_anchor_xmaxs, tmp_gt_bbox
+         return None
+ 
+     def _get_train_label(self, gt_bbox, anchor_xmin, anchor_xmax):
+         gt_bbox = np.array(gt_bbox)
+         gt_xmins = gt_bbox[:, 0]
+         gt_xmaxs = gt_bbox[:, 1]
+         # Same quantity as gt_len; named to match the THUMOS reference code.
+         gt_duration = gt_xmaxs - gt_xmins
+         gt_duration_boundary = np.maximum(
+             self.skip_videoframes, gt_duration * self.boundary_ratio)
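+         # Boundary regions are windows of width max(skip_videoframes,
+         # boundary_ratio * duration) centered on each segment's start and end.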
+         gt_start_bboxs = np.stack(
+             (gt_xmins - gt_duration_boundary / 2, gt_xmins + gt_duration_boundary / 2),
+             axis=1)
+         gt_end_bboxs = np.stack(
+             (gt_xmaxs - gt_duration_boundary / 2, gt_xmaxs + gt_duration_boundary / 2),
+             axis=1)
+ 
+         match_score_action = [
+             np.max(
+                 ioa_with_anchors(anchor_xmin[jdx], anchor_xmax[jdx],
+                                  gt_xmins, gt_xmaxs))
+             for jdx in range(len(anchor_xmin))
+         ]
+ 
+         match_score_start = [
+             np.max(
+                 ioa_with_anchors(anchor_xmin[jdx], anchor_xmax[jdx],
+                                  gt_start_bboxs[:, 0], gt_start_bboxs[:, 1]))
+             for jdx in range(len(anchor_xmin))
+         ]
+ 
+         match_score_end = [
+             np.max(
+                 ioa_with_anchors(anchor_xmin[jdx], anchor_xmax[jdx],
+                                  gt_end_bboxs[:, 0], gt_end_bboxs[:, 1]))
+             for jdx in range(len(anchor_xmin))
+         ]
+ 
+         return torch.Tensor(match_score_action), torch.Tensor(match_score_start), torch.Tensor(match_score_end)
+ 
+ 
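+ # A minimal usage sketch (hypothetical option values, not part of this commit):
+ #     opt = {'video_info': 'anno.csv', 'mode': 'train', 'boundary_ratio': 0.1,
+ #            'skip_videoframes': 5, 'num_videoframes': 100,
+ #            'dist_videoframes': 375, 'fps': 30,
+ #            'train_video_file_list': 'train_list.txt'}
+ #     dataset = VideoDataset(opt, transforms, subset='train')
+ #     loader = data.DataLoader(dataset, batch_size=4, shuffle=True)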

class ProposalSampler(data.WeightedRandomSampler):
    def __init__(self, proposals, frame_list, max_zero_weight=0.25):
        """
@@ -558,8 +753,3 @@ def make_on_anno_files(mmd, videotable):
            current_end = e
        wv['annotations'].append({'label': 'on', 'segment': [current_start, current_end]})
        onmmd_anno[k] = wv
- 
- 
- 
- 
- 