-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathreproduce_results.py
87 lines (68 loc) · 2.87 KB
/
reproduce_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import glob, os
import sed_eval
import argparse
import dcase_util
# Command-line interface. Note: arguments are parsed at import time, so
# importing this module parses sys.argv.
# "-dataset_path" (single dash) is kept for existing invocations;
# "--dataset_path" is accepted as the conventional long-option spelling.
# Both store into args.dataset_path.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-dataset_path",
    "--dataset_path",
    dest="dataset_path",
    required=True,
    type=str,
    help="root path for PodcastFillers dataset",
)
args = parser.parse_args()
def sed_eval_reproduce(gt_folder, est_folder, t_collar=0.1, time_resolution=0.1):
    """Evaluate AVC-FillerNet predictions against ground truth with sed_eval.

    Every ``*.txt`` file in ``gt_folder`` is paired with the file of the same
    name in ``est_folder``. Event-based and segment-based metrics are
    accumulated over all pairs and printed (event-based first).

    Args:
        gt_folder (str): folder path for ground truth in sed_eval supported
            txt files.
        est_folder (str): folder path for AVC-FillerNet predictions in sed_eval
            supported txt files, one per ground-truth file (same filename).
        t_collar (float, optional): collar size for the event-based sed_eval
            evaluation; the paper uses 0.1 s.
        time_resolution (float, optional): segment size passed to the
            segment-based metrics. Defaults to 0.1, the value previously
            hard-coded.

    Returns:
        tuple: ``(event_based_metrics, segment_based_metrics)`` after all file
        pairs have been evaluated.

    Raises:
        FileNotFoundError: if ``gt_folder`` contains no ``*.txt`` files (e.g. a
            wrong dataset path), or a ground-truth file has no prediction file
            with the same name in ``est_folder``.
    """
    file_pairs = _pair_reference_and_estimate(gt_folder, est_folder)

    # Load every event list up front. Event labels are collected from the
    # ground-truth lists only, and the metric objects need the label list
    # before any file is evaluated.
    data = []
    all_data = dcase_util.containers.MetaDataContainer()
    for reference_file, estimated_file in file_pairs:
        reference_event_list = sed_eval.io.load_event_list(filename=reference_file)
        estimated_event_list = sed_eval.io.load_event_list(filename=estimated_file)
        data.append((reference_event_list, estimated_event_list))
        all_data += reference_event_list
    event_labels = all_data.unique_event_labels

    # Create metrics classes, define parameters.
    segment_based_metrics = sed_eval.sound_event.SegmentBasedMetrics(
        event_label_list=event_labels, time_resolution=time_resolution
    )
    event_based_metrics = sed_eval.sound_event.EventBasedMetrics(
        event_label_list=event_labels, t_collar=t_collar
    )

    # Accumulate both metrics over all files.
    for reference_event_list, estimated_event_list in data:
        segment_based_metrics.evaluate(
            reference_event_list=reference_event_list,
            estimated_event_list=estimated_event_list,
        )
        event_based_metrics.evaluate(
            reference_event_list=reference_event_list,
            estimated_event_list=estimated_event_list,
        )

    print(event_based_metrics)
    print(segment_based_metrics)
    return event_based_metrics, segment_based_metrics


def _pair_reference_and_estimate(gt_folder, est_folder):
    """Return (ground-truth path, prediction path) pairs sorted by filename.

    Raises FileNotFoundError when no ground-truth files exist or a prediction
    file is missing, so a bad path fails loudly instead of yielding empty
    metrics.
    """
    # Sorted so the evaluation order does not depend on filesystem listing order.
    gt_files = sorted(glob.glob(os.path.join(gt_folder, "*.txt")))
    if not gt_files:
        raise FileNotFoundError(
            f"No ground-truth *.txt files found in {gt_folder!r}; "
            "check the dataset path."
        )
    pairs = []
    for gt_file in gt_files:
        est_file = os.path.join(est_folder, os.path.basename(gt_file))
        if not os.path.isfile(est_file):
            raise FileNotFoundError(
                f"Missing prediction file {est_file!r} for ground truth {gt_file!r}."
            )
        pairs.append((gt_file, est_file))
    return pairs
if __name__ == "__main__":
    # Ground truth and predictions for each table live under
    # <dataset_path>/meta_data/episode_sed_eval_paper/{ground_truth,AVCFillerNet_predictions}/Table<i>.
    sed_eval_root = os.path.join(
        args.dataset_path, "meta_data", "episode_sed_eval_paper"
    )
    for i in range(2):
        print(f"Reproducing Table {i}:")
        table_dir = "Table" + str(i)
        gt_folder = os.path.join(sed_eval_root, "ground_truth", table_dir)
        est_folder = os.path.join(sed_eval_root, "AVCFillerNet_predictions", table_dir)
        sed_eval_reproduce(gt_folder, est_folder)