-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_final_tables.py
More file actions
296 lines (254 loc) · 8.65 KB
/
generate_final_tables.py
File metadata and controls
296 lines (254 loc) · 8.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
import os
import torch
from ultralytics import YOLO
# Fix for OpenMP error: sets KMP_DUPLICATE_LIB_OK so that more than one copy
# of the OpenMP runtime in the same process is tolerated instead of aborting.
# NOTE(review): this assignment runs after `torch` and `ultralytics` have
# already been imported; presumably the variable has to be in place before the
# OpenMP runtime initialises -- confirm whether it should move above the imports.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Reference image list whose file names define the validation split.
_REFERENCE_VAL_LIST = "data/rtts_original_val.txt"

# Class names written into every generated dataset YAML, in class-id order
# (the 80-entry list the YAML template carries: 0 = person ... 79 = toothbrush).
_COCO_CLASS_NAMES = (
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
    "truck", "boat", "traffic light", "fire hydrant", "stop sign",
    "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag",
    "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot",
    "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant",
    "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote",
    "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush",
)

# Result key -> class id for the three classes that are evaluated and reported.
_REPORTED_CLASSES = {"Car": 2, "Motorcycle": 3, "Bus": 5}


def _write_val_list(data_dir, val_txt_path, reference_list=_REFERENCE_VAL_LIST):
    """Write `val_txt_path` listing the reference split's files under `data_dir/images`."""
    with open(reference_list, 'r') as f:
        entries = [line.strip() for line in f]

    # Only the file name is kept from each reference entry; the directory part
    # is replaced by `<data_dir>/images`. Backslashes are normalised first so a
    # list produced on Windows still yields bare file names on POSIX.
    image_paths = [
        os.path.join(data_dir, "images", os.path.basename(entry.replace("\\", "/")))
        for entry in entries
        if entry
    ]

    with open(val_txt_path, 'w', encoding="utf-8") as f:
        f.write('\n'.join(image_paths))


def _write_dataset_yaml(yaml_path, val_txt_path):
    """Write a dataset YAML whose train and val entries both point at `val_txt_path`."""
    header = (
        "\n"
        "path: .\n"
        f"train: {val_txt_path} # We don't need train for benchmark, but YOLO needs the key\n"
        f"val: {val_txt_path}\n"
        "names:\n"
    )
    # Entries must be nested under `names:` for the mapping to be valid YAML.
    names = "".join(
        f"  {class_id}: {class_name}\n"
        for class_id, class_name in enumerate(_COCO_CLASS_NAMES)
    )
    with open(yaml_path, 'w', encoding="utf-8") as f:
        f.write(header + names)


def benchmark_model_custom(data_dir, model_name, weights_path='yolov10m.pt'):
    """Validate a YOLO checkpoint on one dataset variant and report per-class AP50.

    The file names of the reference split (``data/rtts_original_val.txt``) are
    re-pointed at ``<data_dir>/images``, a matching dataset YAML is generated,
    and ``YOLO(weights_path).val(...)`` is run restricted to classes 2, 3 and 5.

    Args:
        data_dir: Dataset root; images are looked up in ``<data_dir>/images``
            under the same file names as in the reference list.
        model_name: Label for this run. It is passed to ``val`` as the run
            name and, lower-cased with spaces replaced by underscores, used to
            name the generated ``data/rtts_<slug>_val.txt`` and
            ``data/rtts_<slug>.yaml`` files.
        weights_path: Checkpoint handed to ``YOLO``.

    Returns:
        A dict with keys ``"Car"``, ``"Motorcycle"``, ``"Bus"`` (the ``ap50``
        value of each class, ``0.0`` when the class is missing from the
        results) and ``"mAP"`` (their arithmetic mean), or ``None`` if loading
        the model or running validation raised.

    Raises:
        OSError: If the reference list cannot be read or the generated files
            cannot be written (e.g. ``FileNotFoundError``).
    """
    print(f"\n🚀 Benchmarking: {model_name}")
    print(f"📂 Data: {data_dir}")

    slug = model_name.lower().replace(' ', '_')
    val_txt_path = f"data/rtts_{slug}_val.txt"
    yaml_path = f"data/rtts_{slug}.yaml"

    _write_val_list(data_dir, val_txt_path)
    print(f"Created val list: {val_txt_path}")
    _write_dataset_yaml(yaml_path, val_txt_path)

    # Best-effort boundary: a failed run is reported and signalled with None so
    # the caller can still tabulate the runs that succeeded.
    try:
        model = YOLO(weights_path)
        metrics = model.val(
            data=yaml_path,
            batch=16,
            imgsz=640,
            conf=0.001,
            iou=0.6,
            split='val',
            project='runs/benchmark_tables',
            name=model_name,
            classes=list(_REPORTED_CLASSES.values()),  # Car, Motor, Bus
            verbose=False,
            plots=False
        )
        # ap_class_index and box.ap50 are parallel sequences; cast to plain
        # Python types so lookups and formatting do not depend on array scalars.
        ap_by_class = {
            int(class_id): float(ap)
            for class_id, ap in zip(metrics.ap_class_index, metrics.box.ap50)
        }
    except Exception as e:
        print(f"❌ Error: {e}")
        return None

    scores = {
        label: ap_by_class.get(class_id, 0.0)
        for label, class_id in _REPORTED_CLASSES.items()
    }
    scores["mAP"] = sum(scores.values()) / len(_REPORTED_CLASSES)
    return scores
def _print_precision_table(title, rows, results):
    """Print one Markdown precision table, skipping rows that have no result."""
    print("\n\n")
    print(title)
    print("| Methods | Car | Bus | Motorcycle | mAP (3-Class) |")
    print("| :--- | :--- | :--- | :--- | :--- |")
    for result_key, label in rows:
        if result_key not in results:
            continue
        r = results[result_key]
        print(f"| {label} | {r['Car']:.3f} | {r['Bus']:.3f} | {r['Motorcycle']:.3f} | {r['mAP']:.3f} |")


def generate_final_tables():
    """Run all benchmarks and print the two Markdown result tables.

    Three runs use the pretrained ``yolov10m.pt`` checkpoint (original-AOD
    dehazed images, advanced-perceptual dehazed images, and the original foggy
    images); a fourth run uses the fine-tuned checkpoint on the
    advanced-perceptual dehazed images, if one of the known weight files
    exists. A run that returns no result is left out of the tables.

    Table 1 compares the dehazing front-ends with the pretrained detector;
    Table 2 compares the foggy baseline with the fine-tuned model.
    """
    results = {}

    # Table 1 rows deliberately all use the same pretrained detector so that
    # only the dehazing step differs between them.
    pretrained_runs = (
        ("--- 1. Benchmarking YOLO + Original AOD ---",
         "YOLO+AOD", "data/RTTS_Dehazed_Original_AOD", "yolo_aod_original"),
        ("--- 2. Benchmarking YOLO + AOD-Advanced (Pretrained) ---",
         "YOLO+Advanced", "data/RTTS_Dehazed_Advanced_Perceptual", "yolo_aod_advanced"),
        ("--- 3. Benchmarking YOLO (Foggy) ---",
         "YOLO", "data/RTTS_Original_Standard", "yolo_foggy"),
    )
    for banner, result_key, data_dir, run_name in pretrained_runs:
        print(banner)
        outcome = benchmark_model_custom(data_dir, run_name, "yolov10m.pt")
        if outcome:
            results[result_key] = outcome

    # --- 4. Improved YOLO (fine-tuned) ---
    print("\n🚀 Benchmarking Improved YOLO (Ours Finetuned)")
    # Preferred weights come from the extended training run; the second entry
    # is the fallback from the interrupted run.
    finetuned_candidates = (
        "runs/train_extended/yolov10m_perceptual_extended/weights/best.pt",
        "runs/detect/runs/train_comparison/yolov10m_rtts_perceptual_train2/weights/best.pt",
    )
    weights_ours = next(
        (path for path in finetuned_candidates if os.path.exists(path)),
        None,
    )
    if weights_ours is not None:
        print(f"Using weights: {weights_ours}")
        # Evaluated on the dehazed images; the helper regenerates the list/YAML.
        res_ours = benchmark_model_custom(
            "data/RTTS_Dehazed_Advanced_Perceptual",
            "yolo_ours_finetuned",
            weights_ours
        )
        if res_ours:
            results["Ours"] = res_ours
    else:
        print("❌ Error: Could not find trained weights for Ours.")

    # Table 1: foggy baseline, + original AOD-Net, + advanced dehazing (the
    # row labelled "this paper's algorithm" in the dehazing comparison).
    _print_precision_table(
        "### Table 1: Average Detection Precision (Ablation Study)",
        (
            ("YOLO", "YOLOv10"),
            ("YOLO+AOD", "YOLOv10 + AOD-Net"),
            ("YOLO+Advanced", "YOLOv10 + 本文算法 (Ours)"),
        ),
        results,
    )

    # Table 2: YOLOv10 foggy baseline vs. the improved (fine-tuned) YOLOv10.
    _print_precision_table(
        "### Table 2: Average Detection Precision (Final Method Comparison)",
        (
            ("YOLO", "YOLOv10"),
            ("Ours", "改进的 YOLOv10 (Ours)"),
        ),
        results,
    )

    print("\nDone.")
# Script entry point: run every benchmark and print both tables when this file
# is executed directly (nothing runs on import).
if __name__ == "__main__":
    generate_final_tables()