Skip to content

Commit dc97f0a

Browse files
committed
[#6] Can handle any type of annotation file.
1 parent a5f8183 commit dc97f0a

File tree

3 files changed

+48
-45
lines changed

3 files changed

+48
-45
lines changed

Diff for: dataset/dataset.py

+17-7
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77

88

99
class DataSet(ABC):
10-
def __init__(self, base_directory: str, target_extension: Union[str, List[str]]):
10+
def __init__(self, base_directory: str, target_extension: Union[str, List[str]], anntoation_extension: Union[None, str]=None):
1111
self.base_directory = base_directory
1212

1313
if type(target_extension) == list:
@@ -17,6 +17,11 @@ def __init__(self, base_directory: str, target_extension: Union[str, List[str]])
1717
else:
1818
raise ValueError("The target_extension should be the string of extension or list of it.")
1919

20+
if anntoation_extension is not None:
21+
self.anntoation_extension = anntoation_extension
22+
else:
23+
self.anntoation_extension = "json"
24+
2025
self.train_pairs = []
2126
self.validation_pairs = []
2227

@@ -41,8 +46,8 @@ def prepare_pairs(self):
4146
train_file_list.sort()
4247
validation_file_list.sort()
4348

44-
self.train_pairs = DataSet.create_pairs_with_json(train_file_list)
45-
self.validation_pairs = DataSet.create_pairs_with_json(validation_file_list)
49+
self.train_pairs = DataSet.create_pairs_with_annotation(train_file_list, self.anntoation_extension)
50+
self.validation_pairs = DataSet.create_pairs_with_annotation(validation_file_list, self.anntoation_extension)
4651

4752
def set_train_mode(self):
4853
self.is_train_mode = True
@@ -75,14 +80,14 @@ def set_random_salt(self, random_salt):
7580
self.random_salt = 20200305 + random_salt
7681

7782
@staticmethod
78-
def create_pairs_with_json(file_list: list):
83+
def create_pairs_with_annotation(file_list: list, annotatino_extension: str):
7984
pairs = []
8085

8186
for file_path in file_list:
82-
json_file_path = os.path.splitext(file_path)[0] + ".json"
87+
annotation_file_path = os.path.splitext(file_path)[0] + ".%s" % annotatino_extension
8388

84-
if os.path.exists(json_file_path) is True:
85-
pairs.append((file_path, json_file_path))
89+
if os.path.exists(annotation_file_path) is True:
90+
pairs.append((file_path, annotation_file_path))
8691
else:
8792
continue
8893

@@ -131,6 +136,11 @@ def validation_datum_filter(file_path: str):
131136
def extract_label(file_path: str):
132137
pass
133138

139+
@staticmethod
140+
@abstractmethod
141+
def parse_annotation(file_path: str):
142+
pass
143+
134144
@staticmethod
135145
@abstractmethod
136146
def is_valid_annotation(image_width, image_height, annotation):

Diff for: dataset/image_dataset.py

+15-18
Original file line number | Diff line number | Diff line change
@@ -16,23 +16,22 @@ def create_valid_indices(self, pairs: list):
1616
description_prefix = "Checking validity: "
1717

1818
tqdm_iterator = tqdm.tqdm(pairs, desc=description_prefix)
19-
for index, (image_file_path, json_file_path) in enumerate(tqdm_iterator):
19+
for index, (image_file_path, annotation_file_path) in enumerate(tqdm_iterator):
2020
tqdm_iterator.set_description(description_prefix + image_file_path)
2121

2222
label = self.extract_label(image_file_path)
2323

2424
image_width, image_height = imagesize.get(image_file_path)
2525

26-
with open(json_file_path) as file:
27-
annotation = json.load(file)
26+
annotation = self.parse_annotation(annotation_file_path)
2827

29-
if self.is_valid_annotation(image_width, image_height, annotation) is False:
30-
continue
28+
if self.is_valid_annotation(image_width, image_height, annotation) is False:
29+
continue
3130

32-
if label not in valid_indices.keys():
33-
valid_indices[label] = []
31+
if label not in valid_indices.keys():
32+
valid_indices[label] = []
3433

35-
valid_indices[label].append(index)
34+
valid_indices[label].append(index)
3635

3736
return valid_indices
3837

@@ -44,23 +43,22 @@ def validation_count(self, label):
4443

4544
def get_train_filename(self, label: Union[int, str], index: int):
4645
pair_index = self.train_valid_indices[label][index]
47-
image_file_path, json_file_path = self.train_pairs[pair_index]
46+
image_file_path, annotation_file_path = self.train_pairs[pair_index]
4847

49-
return image_file_path, json_file_path
48+
return image_file_path, annotation_file_path
5049

5150
def get_validation_filename(self, label: Union[int, str], index: int):
5251
pair_index = self.validation_valid_indices[label][index]
53-
image_file_path, json_file_path = self.validation_pairs[pair_index]
52+
image_file_path, annotation_file_path = self.validation_pairs[pair_index]
5453

55-
return image_file_path, json_file_path
54+
return image_file_path, annotation_file_path
5655

5756
def get_train_datum(self, label, index):
5857
pair_index = self.train_valid_indices[label][index]
5958

60-
image_file_path, json_file_path = self.train_pairs[pair_index]
59+
image_file_path, annotation_file_path = self.train_pairs[pair_index]
6160

62-
with open(json_file_path) as json_file:
63-
annotation = json.load(json_file)
61+
annotation = self.parse_annotation(annotation_file_path)
6462

6563
image = cv2.imread(image_file_path)
6664

@@ -69,10 +67,9 @@ def get_train_datum(self, label, index):
6967
def get_validation_datum(self, label, index):
7068
pair_index = self.validation_valid_indices[label][index]
7169

72-
image_file_path, json_file_path = self.validation_pairs[pair_index]
70+
image_file_path, annotation_file_path = self.validation_pairs[pair_index]
7371

74-
with open(json_file_path) as json_file:
75-
annotation = json.load(json_file)
72+
annotation = self.parse_annotation(annotation_file_path)
7673

7774
image = cv2.imread(image_file_path)
7875

Diff for: dataset/movie_dataset.py

+16-20
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
from typing import Union
66

77
import cv2
8-
import json
98
import tqdm
109

1110

@@ -16,7 +15,7 @@ def create_valid_indices(self, pairs: list):
1615
description_prefix = "Checking validity: "
1716

1817
tqdm_iterator = tqdm.tqdm(pairs, desc=description_prefix)
19-
for video_index, (video_file_path, json_file_path) in enumerate(tqdm_iterator):
18+
for video_index, (video_file_path, annotation_file_path) in enumerate(tqdm_iterator):
2019
tqdm_iterator.set_description(description_prefix + video_file_path)
2120

2221
label = self.extract_label(video_file_path)
@@ -28,14 +27,13 @@ def create_valid_indices(self, pairs: list):
2827

2928
valid_frame_indices = []
3029

31-
with open(json_file_path) as file:
32-
annotations = json.load(file)
30+
annotations = self.parse_annotation(annotation_file_path)
3331

34-
for frame_index, annotation in enumerate(annotations):
35-
if self.is_valid_annotation(video_width, video_height, annotation) is False:
36-
continue
32+
for frame_index, annotation in enumerate(annotations):
33+
if self.is_valid_annotation(video_width, video_height, annotation) is False:
34+
continue
3735

38-
valid_frame_indices.append(frame_index)
36+
valid_frame_indices.append(frame_index)
3937

4038
if len(valid_frame_indices) == 0:
4139
continue
@@ -57,17 +55,17 @@ def get_train_filename(self, label: Union[int, str], index: int):
5755
video_indices = sorted(list(self.train_valid_indices[label].keys()))
5856
video_index = video_indices[index]
5957

60-
movie_file_path, json_file_path = self.train_pairs[video_index]
58+
movie_file_path, annotation_file_path = self.train_pairs[video_index]
6159

62-
return movie_file_path, json_file_path
60+
return movie_file_path, annotation_file_path
6361

6462
def get_validation_filename(self, label: Union[int, str], index: int):
6563
video_indices = sorted(list(self.validation_valid_indices[label].keys()))
6664
video_index = video_indices[index]
6765

68-
movie_file_path, json_file_path = self.validation_pairs[video_index]
66+
movie_file_path, annotation_file_path = self.validation_pairs[video_index]
6967

70-
return movie_file_path, json_file_path
68+
return movie_file_path, annotation_file_path
7169

7270
def get_train_datum(self, label, index):
7371
video_indices = sorted(list(self.train_valid_indices[label].keys()))
@@ -76,11 +74,10 @@ def get_train_datum(self, label, index):
7674
random.seed(None)
7775
frame_index = random.choice(self.train_valid_indices[label][video_index])
7876

79-
movie_file_path, json_file_path = self.train_pairs[video_index]
77+
movie_file_path, annotation_file_path = self.train_pairs[video_index]
8078

81-
with open(json_file_path) as json_file:
82-
annotations = json.load(json_file)
83-
annotation = annotations[frame_index]
79+
annotations = self.parse_annotation(annotation_file_path)
80+
annotation = annotations[frame_index]
8481

8582
video = cv2.VideoCapture(movie_file_path)
8683
video.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
@@ -96,11 +93,10 @@ def get_validation_datum(self, label, index):
9693
random.seed(self.random_salt + index)
9794
frame_index = random.choice(self.validation_valid_indices[label][video_index])
9895

99-
movie_file_path, json_file_path = self.validation_pairs[video_index]
96+
movie_file_path, annotation_file_path = self.validation_pairs[video_index]
10097

101-
with open(json_file_path) as json_file:
102-
annotations = json.load(json_file)
103-
annotation = annotations[frame_index]
98+
annotations = self.parse_annotation(annotation_file_path)
99+
annotation = annotations[frame_index]
104100

105101
video = cv2.VideoCapture(movie_file_path)
106102
video.set(cv2.CAP_PROP_POS_FRAMES, frame_index)

0 commit comments

Comments
 (0)