-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbenchmark_alternatives.py
More file actions
282 lines (248 loc) · 7.92 KB
/
benchmark_alternatives.py
File metadata and controls
282 lines (248 loc) · 7.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import os
import torch
import shutil
import glob
from ultralytics import YOLO
# Directory holding the reference RTTS label files (relative to the current
# working directory). ensure_labels_for_dataset copies any label files from
# here that are missing in a benchmarked dataset's own "labels" folder.
LABELS_SOURCE_DIR: str = "data/RTTS_Original_Standard/labels"
def clear_yolo_cache(data_dir):
    """
    Delete every ``*.cache`` file under ``data_dir`` (searched recursively).

    This forces YOLO to rebuild its dataset cache instead of reusing stale
    dataset info from a previous run. Deletion is best-effort: a file that
    cannot be removed is reported and skipped rather than aborting.

    Args:
        data_dir: Root directory to search. Glob metacharacters in it
            (``[``, ``*``, ``?``) are matched literally.
    """
    # Escape the root so a directory name such as "set[1]" is not read as a
    # glob character class (which would silently match nothing).
    pattern = os.path.join(glob.escape(data_dir), "**", "*.cache")
    for cache_file in glob.glob(pattern, recursive=True):
        try:
            os.remove(cache_file)
            print(f"🧹 Deleted cache: {cache_file}")
        except OSError as e:
            print(f"⚠️ Could not delete {cache_file}: {e}")
def ensure_labels_for_dataset(data_dir, labels_source_dir=None):
    """
    Ensure ``<data_dir>/labels`` exists and contains every source label file.

    YOLO validation of ``data_dir`` depends on these label files being present.
    Only files missing from the destination are copied; label files that
    already exist there are left untouched. Sub-directories of the source are
    ignored.

    Args:
        data_dir: Dataset root; labels are placed in ``<data_dir>/labels``
            (created if needed).
        labels_source_dir: Directory to copy label files from. Defaults to the
            module-level ``LABELS_SOURCE_DIR``.

    Returns:
        True if the destination now holds every source label file; False if
        the source directory is missing or any copy failed.
    """
    source_dir = LABELS_SOURCE_DIR if labels_source_dir is None else labels_source_dir
    labels_dir = os.path.join(data_dir, "labels")

    if not os.path.isdir(labels_dir):
        os.makedirs(labels_dir, exist_ok=True)
        print(f"📁 Created labels directory: {labels_dir}")

    if not os.path.isdir(source_dir):
        print(f"❌ Source labels directory not found: {source_dir}")
        return False

    # Only regular files are label files; copy2 on a sub-directory always fails.
    source_labels = {
        name
        for name in os.listdir(source_dir)
        if os.path.isfile(os.path.join(source_dir, name))
    }
    missing_labels = sorted(source_labels - set(os.listdir(labels_dir)))

    if not missing_labels:
        print(f"✅ Labels already exist for {data_dir}")
        return True

    print(f"📋 Copying {len(missing_labels)} missing label files...")
    failed = 0
    for label_file in missing_labels:
        try:
            shutil.copy2(
                os.path.join(source_dir, label_file),
                os.path.join(labels_dir, label_file),
            )
        except OSError as e:
            failed += 1
            print(f"⚠️ Failed to copy {label_file}: {e}")

    if failed:
        # Do not report success when the label set is incomplete.
        print(f"❌ {failed} label file(s) could not be copied to {labels_dir}")
        return False

    print(f"✅ Labels synchronized for {data_dir}")
    return True
# Class names written into the generated dataset YAML, in index order (0-79).
_COCO_NAMES = (
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
    "truck", "boat", "traffic light", "fire hydrant", "stop sign",
    "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag",
    "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot",
    "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant",
    "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote",
    "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush",
)

# Class indices scored by the benchmark -> result-dict key (order matters:
# it fixes both the `classes=` filter and the key order of the result).
_TARGET_CLASSES = {2: "Car", 3: "Motorcycle", 5: "Bus"}


def _read_lines_with_fallback(path):
    """Return the lines of ``path``, decoded as UTF-8 or, failing that, GBK."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.readlines()
    except UnicodeDecodeError:
        with open(path, "r", encoding="gbk") as f:
            return f.readlines()


def _remap_to_dataset(orig_paths, data_dir):
    """Re-point each listed image at ``<data_dir>/images/<basename>``.

    Returns ``(existing_absolute_paths, missing_count)``; blank lines are skipped.
    """
    new_paths = []
    missing_count = 0
    for raw in orig_paths:
        p = raw.strip()
        if not p:
            continue
        new_p_abs = os.path.abspath(os.path.join(data_dir, "images", os.path.basename(p)))
        if os.path.exists(new_p_abs):
            new_paths.append(new_p_abs)
        else:
            missing_count += 1
            if missing_count <= 5:  # only show a few examples to avoid log spam
                print(f"⚠️ Missing file: {new_p_abs}")
    return new_paths, missing_count


def _build_dataset_yaml(val_txt_path):
    """Return dataset YAML using ``val_txt_path`` for both splits and the COCO names.

    ``train`` only fills the key; this script never trains. Class entries are
    indented so they form a mapping under ``names:``.
    """
    name_lines = "\n".join(f"  {i}: {name}" for i, name in enumerate(_COCO_NAMES))
    return (
        "path: .\n"
        f"train: {val_txt_path}\n"
        f"val: {val_txt_path}\n"
        "names:\n"
        f"{name_lines}\n"
    )


def benchmark_model_custom(data_dir, model_name, weights_path='yolov10m.pt'):
    """
    Evaluate YOLO weights on one version of the RTTS validation images.

    The image list in ``data/rtts_original_val.txt`` is re-pointed at
    ``<data_dir>/images`` (matched by file name). The remapped list and a
    dataset YAML are written under ``data/``. The model is then validated
    with ``model.val``, restricted to the classes in ``_TARGET_CLASSES``
    (car, motorcycle, bus).

    Args:
        data_dir: Dataset root containing ``images/``. Stale ``*.cache`` files
            under it are removed and labels are synced into ``labels/`` first.
        model_name: Run label. Used in the generated file names and as the run
            name under ``runs/benchmark_alternatives``.
        weights_path: YOLO weights file to load.

    Returns:
        A dict with keys "Car", "Motorcycle", "Bus" holding AP@0.5 values.
        A class with no reported AP counts as 0.0. "mAP" is the plain mean
        of the three. Returns None if no images were found or validation
        raised.
    """
    print(f"\n🚀 Benchmarking: {model_name}")
    print(f"📂 Data: {data_dir}")
    clear_yolo_cache(data_dir)
    ensure_labels_for_dataset(data_dir)

    slug = model_name.lower().replace(' ', '_')
    val_txt_path = f"data/rtts_{slug}_val.txt"
    yaml_path = f"data/rtts_{slug}.yaml"

    new_paths, missing_count = _remap_to_dataset(
        _read_lines_with_fallback("data/rtts_original_val.txt"), data_dir
    )
    if missing_count > 0:
        print(f"❌ Total missing files: {missing_count}")
    if not new_paths:
        print("❌ No valid images found. Skipping benchmark.")
        return None

    print("📝 First 3 paths in val.txt:")
    for path in new_paths[:3]:
        print(f" {path}")
    with open(val_txt_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(new_paths))
    with open(yaml_path, 'w', encoding='utf-8') as f:
        f.write(_build_dataset_yaml(val_txt_path))

    try:
        model = YOLO(weights_path)
        metrics = model.val(
            data=yaml_path,
            batch=16,
            imgsz=640,
            conf=0.001,
            iou=0.6,
            split='val',
            project='runs/benchmark_alternatives',
            name=model_name,
            classes=list(_TARGET_CLASSES),
            verbose=False,
            plots=False,
        )
        # box.ap50 is paired element-wise with ap_class_index (as zipped in
        # the original implementation).
        ap_by_class = dict(zip(metrics.ap_class_index, metrics.box.ap50))
    except Exception as e:  # one failed run must not abort the whole comparison
        import traceback
        print(f"❌ Error: {e}")
        traceback.print_exc()
        return None

    result = {
        label: float(ap_by_class.get(idx, 0.0))
        for idx, label in _TARGET_CLASSES.items()
    }
    result["mAP"] = sum(result.values()) / len(_TARGET_CLASSES)
    return result
# (table label, section title, dataset directory, run name), in table order.
_ALTERNATIVES = (
    ("DehazeNet", "DehazeNet (Pretrained YOLO)", "data/RTTS_Dehazed_DehazeNet", "yolo_dehazenet"),
    ("DCP", "DCP (Pretrained YOLO)", "data/RTTS_Dehazed_DCP", "yolo_dcp"),
    ("Foggy", "Foggy (Baseline)", "data/RTTS_Original_Standard", "yolo_foggy"),
    ("AOD-Net", "AOD-Net", "data/RTTS_Dehazed_Original_AOD", "yolo_aod"),
    ("AOD-Advanced", "AOD-Advanced", "data/RTTS_Dehazed_Advanced_Perceptual", "yolo_advanced"),
)


def benchmark_alternatives():
    """
    Benchmark each dehazing variant in ``_ALTERNATIVES`` with ``yolov10m.pt``
    and print a Markdown comparison table.

    Every variant whose dataset directory does not exist is skipped. That
    avoids running a benchmark on missing data, and avoids the label sync
    creating an empty ``<dir>/labels`` tree for it.

    Returns:
        Dict mapping method label to the per-class AP dict from
        ``benchmark_model_custom``. Only runs that produced results are
        included; this is the same data the table is printed from.
    """
    results = {}
    for label, title, data_dir, run_name in _ALTERNATIVES:
        print(f"--- Benchmarking {title} ---")
        if not os.path.exists(data_dir):
            print(f"Skipping {label}: {data_dir} not found")
            continue
        res = benchmark_model_custom(data_dir, run_name, "yolov10m.pt")
        if res:
            results[label] = res

    print("\n\n")
    print("### Table: Comparison of Dehazing Algorithms (Pretrained YOLOv10m)")
    print("| Method | Car | Bus | Motorcycle | mAP (3-Class) |")
    print("| :--- | :--- | :--- | :--- | :--- |")
    for name, res in results.items():
        cells = " | ".join(
            f"{res[key] * 100:.3f}%" for key in ("Car", "Bus", "Motorcycle", "mAP")
        )
        print(f"| {name} | {cells} |")
    return results


if __name__ == "__main__":
    benchmark_alternatives()