-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfull_benchmark_visualization.py
More file actions
201 lines (166 loc) · 8.08 KB
/
full_benchmark_visualization.py
File metadata and controls
201 lines (166 loc) · 8.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import os
import shutil
import yaml
import glob
import xml.etree.ElementTree as ET
from tqdm import tqdm
import cv2
import torch
from main import run_dehaze
from ultralytics import YOLO
# --- 配置 ---
RTTS_ROOT = 'data/RTTS'
DEHAZED_ROOT = 'data/RTTS_Dehazed_AOD' # 改为 AOD 目录
MODEL_WEIGHTS = 'weights/AOD_final.pth' # 使用原版 AOD 权重
YOLO_WEIGHTS = 'weights/yolov10m.pt'
CLASSES = ['car', 'bus', 'truck', 'motorcycle', 'bicycle'] # RTTS/RESIDE 类别可能略有不同,需确认
# RTTS XML 类别通常是: person, bicycle, car, motorbike, bus, truck
# 我们需要映射到 COCO 类别 ID
# COCO: 0:person, 1:bicycle, 2:car, 3:motorcycle, 5:bus, 7:truck
CLASS_MAP = {
'person': 0,
'bicycle': 1,
'car': 2,
'motorbike': 3,
'motorcycle': 3, # 兼容写法
'bus': 5,
'truck': 7
}
def convert_xml_to_txt():
print("正在转换 XML 标签为 YOLO TXT 格式...")
xml_dir = os.path.join(RTTS_ROOT, 'Annotations')
txt_dir = os.path.join(RTTS_ROOT, 'labels')
if not os.path.exists(txt_dir):
os.makedirs(txt_dir)
xml_files = glob.glob(os.path.join(xml_dir, '*.xml'))
for xml_file in tqdm(xml_files):
tree = ET.parse(xml_file)
root = tree.getroot()
# 获取图片尺寸
size = root.find('size')
w = int(size.find('width').text)
h = int(size.find('height').text)
txt_filename = os.path.basename(xml_file).replace('.xml', '.txt')
txt_path = os.path.join(txt_dir, txt_filename)
with open(txt_path, 'w') as f:
for obj in root.findall('object'):
cls_name = obj.find('name').text.lower()
if cls_name not in CLASS_MAP:
continue
cls_id = CLASS_MAP[cls_name]
xmlbox = obj.find('bndbox')
b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
# 归一化 xywh
bb = ((b[0] + b[1]) / 2.0 / w, (b[2] + b[3]) / 2.0 / h,
(b[1] - b[0]) / w, (b[3] - b[2]) / h)
f.write(f"{cls_id} {bb[0]:.6f} {bb[1]:.6f} {bb[2]:.6f} {bb[3]:.6f}\n")
def process_dehazing():
print("正在对 RTTS 全量去雾 (AOD-Advanced)...")
src_img_dir = os.path.join(RTTS_ROOT, 'JPEGImages')
dst_img_dir = os.path.join(DEHAZED_ROOT, 'images')
dst_label_dir = os.path.join(DEHAZED_ROOT, 'labels')
if not os.path.exists(dst_img_dir):
os.makedirs(dst_img_dir)
if not os.path.exists(dst_label_dir):
os.makedirs(dst_label_dir)
# 1. 复制标签 (直接从 RTTS/labels 复制)
print("复制标签文件...")
src_label_dir = os.path.join(RTTS_ROOT, 'labels')
if not os.path.exists(src_label_dir):
convert_xml_to_txt() # 如果不存在则先转换
# 复制所有 txt
for txt_file in glob.glob(os.path.join(src_label_dir, '*.txt')):
shutil.copy(txt_file, dst_label_dir)
# 2. 图片去雾
images = glob.glob(os.path.join(src_img_dir, '*.*'))
# 为了演示,我们这里只处理前 100 张?不,用户要求全量。
# 如果全量太慢,可以先跑一部分。这里我们跑全量。
print(f"共发现 {len(images)} 张图片,开始处理...")
for img_path in tqdm(images):
img_name = os.path.basename(img_path)
save_path = os.path.join(dst_img_dir, img_name)
# 如果已经存在,跳过 (支持断点续传)
if os.path.exists(save_path):
continue
try:
# 调用 main.py 中的 run_dehaze (返回 RGB numpy)
# 我们需要保存它
# 使用 'aod' 模型名
dehazed = run_dehaze(img_path, 'aod', MODEL_WEIGHTS, output_path=None)
if dehazed is not None:
# RGB -> BGR for opencv save
dehazed_bgr = cv2.cvtColor(dehazed, cv2.COLOR_RGB2BGR)
cv2.imwrite(save_path, dehazed_bgr)
except Exception as e:
print(f"Error processing {img_name}: {e}")
def run_yolo_val():
print("开始运行 YOLOv10 验证...")
# 1. 创建 dataset.yaml
yaml_path = os.path.join('data', 'rtts_dehazed_aod.yaml')
# 修正: 使用完整的 COCO 80 类,避免索引越界
# YOLO 权重是基于 80 类训练的,所以 yaml 里必须有 80 个类
# 即使我们的数据只有其中 6 个,标签里的 ID (如 7) 也是合法的
COCO_NAMES = {
0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck',
8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench',
14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear',
22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase',
29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat',
35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle',
40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana',
47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza',
54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table',
61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone',
68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock',
75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'
}
dataset_cfg = {
'path': os.path.abspath(DEHAZED_ROOT), # 绝对路径更安全
'train': 'images', # 只有验证集,这里随便填,或者填一样
'val': 'images',
'names': COCO_NAMES
}
# 注意:这里我们只保留了 RTTS 用到的几个 COCO 类别
# 但 YOLOv10m 是在 COCO 80类上训练的。
# 验证时,我们需要确保 dataset.yaml 的 names 索引与 COCO 一致,或者使用 coco.yaml 过滤
# 最简单的办法:直接用标准的 coco.yaml 格式,但指向我们的数据
# 修正 dataset_cfg: 使用标准 COCO 索引
# 我们只关心 RTTS 里有的那些类。
# 为了方便,我们直接生成一个只包含我们数据的 yaml,但索引 ID 必须对齐 COCO
# 上面的 names 字典就是对齐的 (0, 1, 2, 3, 5, 7)
with open(yaml_path, 'w') as f:
yaml.dump(dataset_cfg, f)
# 2. 运行验证
# 使用通用 YOLO 类,加载 v10 权重即可
model = YOLO(YOLO_WEIGHTS)
# 验证
# save_json=True 会保存预测结果 json
# plots=True 会画出 PR 曲线等
# classes 参数: 指定只验证特定的类别 ID
# 0:person, 1:bicycle, 2:car, 3:motorcycle, 5:bus, 7:truck
# 既然 RTTS 主要是车辆检测,我们最好只关注车辆类别,或者确保 GT 里包含 person
# 如果 RTTS 的 person 标注不全,建议去掉 0
# 这里我们保留 0, 1, 2, 3, 5, 7,但如果想更纯粹的车辆 mAP,可以去掉 0
TARGET_CLASSES = [0, 1, 2, 3, 5, 7]
metrics = model.val(
data=yaml_path,
batch=8,
imgsz=640,
conf=0.001, # 验证时通常用极低置信度以计算 mAP
iou=0.6,
split='val',
project='runs/detect',
name='val_rtts_aod', # 改为 AOD 验证名
classes=TARGET_CLASSES, # 关键修复: 只统计这些类,忽略其他 COCO 杂类
save=True,
plots=True
)
print(f"验证完成!结果保存在 runs/detect/val_rtts_aod")
print(f"mAP50: {metrics.box.map50}")
print(f"mAP50-95: {metrics.box.map}")
if __name__ == "__main__":
# 1. 准备数据 (XML->TXT, 去雾)
process_dehazing()
# 2. 运行验证
run_yolo_val()