-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbenchmark_car_only.py
More file actions
120 lines (100 loc) · 5.15 KB
/
benchmark_car_only.py
File metadata and controls
120 lines (100 loc) · 5.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import pandas as pd
from ultralytics import YOLO, settings
# Work around the OpenMP "duplicate runtime already initialized" error by letting a
# second copy of the OpenMP library load.
# NOTE(review): this runs after the ultralytics import at the top of the file; if the
# OpenMP error still appears, this assignment may need to move above the imports.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Force Ultralytics' dataset root to the current working directory. Presumably this is
# so the relative `path: .` written into the generated dataset YAMLs resolves to the
# cwd -- verify against the installed Ultralytics version.
settings.update({'datasets_dir': os.getcwd()})

# Checkpoint names to benchmark (looked up under weights/ first, then by bare name).
MODELS = ['yolov10m']

# Display name -> dataset root. Each root is expected to contain an `images/` folder.
DATASETS = dict([
    ('Original', 'data/RTTS_Original_Standard'),
    ('AOD-Net', 'data/RTTS_Dehazed_AOD'),
    ('AOD-Advanced', 'data/RTTS_Dehazed_Advanced'),
    ('Advanced (Perceptual)', 'data/RTTS_Dehazed_Advanced_Perceptual'),
])

# COCO class id -> display/column name for the vehicle classes being scored.
# RTTS has no Truck class (COCO id 7), so it is deliberately left out.
CLASS_NAMES_MAP = {2: 'Car', 3: 'Motorcycle', 5: 'Bus'}
# Class ids to evaluate, in the same order as CLASS_NAMES_MAP: [2, 3, 5].
TARGET_CLASSES = list(CLASS_NAMES_MAP)
# Full 80-entry COCO name table, written into every generated dataset YAML so the
# class ids there follow the same COCO numbering as TARGET_CLASSES
# (car=2, motorcycle=3, bus=5).
_COCO_NAMES = {
    0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus',
    6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant',
    11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat',
    16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear',
    22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag',
    27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard',
    32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove',
    36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle',
    40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl',
    46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli',
    51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair',
    57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet',
    62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard',
    67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink',
    72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors',
    77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush',
}


def _write_dataset_yaml(yaml_path, data_root):
    """Write an Ultralytics dataset YAML whose train and val splits both point at
    ``<data_root>/images`` and whose ``names`` table is the full COCO list.

    Args:
        yaml_path: Destination file; its parent directory is created if missing.
        data_root: Dataset root (relative to the YAML's ``path: .``).
    """
    yaml_dir = os.path.dirname(yaml_path)
    if yaml_dir:
        os.makedirs(yaml_dir, exist_ok=True)
    lines = [
        "path: .\n",
        f"train: {data_root}/images\n",
        # Benchmark-only: validation runs on the same image folder as "train".
        f"val: {data_root}/images\n",
        "names:\n",
    ]
    lines.extend(f"  {class_id}: {name}\n" for class_id, name in _COCO_NAMES.items())
    with open(yaml_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)


def _per_class_ap50(metrics):
    """Return ``{class_id: AP@0.5}`` for every class the validator reported.

    ``metrics.box.ap50`` is ordered like ``metrics.ap_class_index``, so the two are
    zipped together. Values are converted to plain Python ints/floats.
    """
    return {
        int(class_id): float(ap)
        for class_id, ap in zip(metrics.ap_class_index, metrics.box.ap50)
    }


def benchmark_car_only():
    """Validate every model in MODELS on every dataset in DATASETS and report the
    per-class AP@0.5 for the vehicle classes in TARGET_CLASSES.

    For each dataset, a dataset YAML ``data/rtts_<name>_vehicle.yaml`` is (re)written
    before validation runs. The results table is printed and saved to
    ``benchmark_vehicle_only.csv``, with columns Model, Dataset and one
    ``<ClassName>_AP`` column per target class.

    A failure on one model/dataset pair (missing files, validation error, ...) is
    reported and skipped so the remaining pairs still run.
    """
    results = []
    print("🚀 Starting Vehicle Benchmark (Car, Bus, Motorcycle)...")

    for model_name in MODELS:
        print(f"\n>>> Testing {model_name}")
        # Prefer a local checkpoint under weights/; otherwise pass the bare name to YOLO().
        local_weights = f"weights/{model_name}.pt"
        model = YOLO(local_weights if os.path.exists(local_weights) else f"{model_name}.pt")

        for data_name, data_root in DATASETS.items():
            # Turn the display name into a filesystem-friendly token for the YAML name.
            safe_data_name = (
                data_name.lower().replace(' ', '_').replace('(', '').replace(')', '')
            )
            yaml_path = f"data/rtts_{safe_data_name}_vehicle.yaml"

            # Best-effort: a single broken dataset must not abort the whole benchmark.
            try:
                _write_dataset_yaml(yaml_path, data_root)
                metrics = model.val(
                    data=yaml_path,
                    batch=16,
                    imgsz=640,
                    conf=0.001,
                    iou=0.6,
                    split='val',
                    project='runs/benchmark_vehicle',
                    name=f"{model_name}_{data_name}",
                    # Restrict to the vehicle classes: Car (2), Motorcycle (3), Bus (5).
                    classes=TARGET_CLASSES,
                    verbose=False,
                    plots=False,
                )
                id_to_ap = _per_class_ap50(metrics)
            except Exception as e:
                print(f"Error ({model_name} / {data_name}): {type(e).__name__}: {e}")
                continue

            row = {'Model': model_name, 'Dataset': data_name}
            for class_id in TARGET_CLASSES:
                class_name = CLASS_NAMES_MAP.get(class_id, str(class_id))
                # Classes the validator reported no AP for are recorded as 0.0.
                row[f"{class_name}_AP"] = id_to_ap.get(class_id, 0.0)
            results.append(row)

            summary = ", ".join(
                f"{CLASS_NAMES_MAP.get(c, str(c))}: {row[f'{CLASS_NAMES_MAP.get(c, str(c))}_AP']:.4f}"
                for c in TARGET_CLASSES
            )
            print(f"    -> {summary}")

    # Build and persist the report.
    df = pd.DataFrame(results)
    print("\n" + "=" * 40)
    print(" 🚗 VEHICLE BENCHMARK 🚗 ")
    print("=" * 40)
    print(df)
    df.to_csv('benchmark_vehicle_only.csv')
    print("\nSaved to benchmark_vehicle_only.csv")


if __name__ == "__main__":
    benchmark_car_only()