-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
180 lines (154 loc) · 6.47 KB
/
main.py
File metadata and controls
180 lines (154 loc) · 6.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""
main.py — program entry point.

Ties the whole pipeline together and exposes four execution modes:
    python main.py --mode preprocess   → step 1: extract and cache frames
    python main.py --mode ssl          → step 2: SSL pre-training
    python main.py --mode finetune     → step 3: fine-tuning
    python main.py --mode online       → step 4: simulated online learning

Full pipeline flow:
    raw videos (.mov/.mp4)
        │
        ▼ [preprocess]
    frame cache (cache/frames/*.pkl)
        │
        ▼ [ssl]
    SSL model (checkpoints/ssl_best.pth)
        │  SimCLR contrastive learning, entirely label-free
        ▼ [finetune]
    classifier (checkpoints/classifier_*.pth)
        │  frozen encoder + trained classification head
        ▼ [online]
    personalized model (checkpoints/user_*_adapter.pth)
        ↑↓ adapter updated on the fly for continued personalization
"""
import argparse
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from config import DATA_DIR, CACHE_DIR, CKPT_DIR, DEVICE
def run_preprocess(args):
    """Step 1: extract face frames from the raw videos and cache them on disk.

    Per the original notes this step only needs to run once; later training
    steps read from the cache.

    Args:
        args: Parsed CLI namespace. Only ``force_rebuild`` is read; it is
            forwarded to ``build_frame_cache``.
    """
    print("\n[Step 1] 資料前處理:萃取人臉幀")

    # Function-scope import: the video utilities are loaded only for this mode.
    from data.video_utils import build_frame_cache, scan_dataset

    cache_index_path = build_frame_cache(
        scan_dataset(),
        force_rebuild=args.force_rebuild,
    )
    print(f"快取索引儲存至: {cache_index_path}")


def run_ssl(args):
    """Step 2: SSL pre-training (SimCLR).

    Per the original notes, this learns a general representation of facial
    fatigue features without using labels.

    Args:
        args: Parsed CLI namespace. ``resume`` is passed as ``resume_ckpt``
            (any falsy value becomes ``None``). ``epochs`` is forwarded only
            when truthy; otherwise the keyword is omitted entirely.
    """
    print("\n[Step 2] SSL 預訓練 (SimCLR)")

    from data.video_utils import load_cache_index
    from training.ssl_pretrain import ssl_pretrain

    frame_index = load_cache_index()
    print(f"[Device] {DEVICE}")

    resume_from = args.resume or None
    # Leave ``epochs`` out unless it was supplied, so ssl_pretrain keeps
    # whatever value it would use on its own.
    epoch_override = {"epochs": args.epochs} if args.epochs else {}

    best_ckpt = ssl_pretrain(
        cache_index=frame_index,
        resume_ckpt=resume_from,
        **epoch_override,
    )
    print(f"SSL 預訓練完成,Checkpoint: {best_ckpt}")


def run_finetune(args):
    """Step 3: fine-tuning.

    Per the original notes, this adds a classification head on top of the
    encoder learned during SSL and trains it with supervision. Exactly one of
    three paths runs, checked in this order:

    * ``args.general`` — ``train_general_model`` on all data (deployment
      model);
    * ``args.loso`` — ``leave_one_out_evaluation`` (leave-one-subject-out);
    * otherwise — ``finetune`` for the single subject ``args.test_subject``.

    Args:
        args: Parsed CLI namespace. Reads ``ssl_ckpt``, ``general``, ``loso``,
            ``stage``, ``epochs`` and ``test_subject``.

    Returns:
        None. When the SSL checkpoint file does not exist, an error and a hint
        are printed and nothing is trained.
    """
    print("\n[Step 3] Fine-tuning")

    from data.video_utils import load_cache_index
    from training.finetune import (
        finetune,
        leave_one_out_evaluation,
        train_general_model,
    )

    cache_index = load_cache_index()

    # Fall back to the default SSL checkpoint location when none was given.
    ssl_ckpt = args.ssl_ckpt or os.path.join(CKPT_DIR, "ssl_best.pth")
    if not os.path.exists(ssl_ckpt):
        print(f"[ERROR] 找不到 SSL checkpoint: {ssl_ckpt}")
        print("請先執行: python main.py --mode ssl")
        return

    # ``epochs`` is forwarded only when truthy, so the training functions keep
    # their own value otherwise.
    epoch_override = {"epochs": args.epochs} if args.epochs else {}

    if args.general:
        # Train the general deployment model on all data.
        train_general_model(
            ssl_ckpt, cache_index, stage=args.stage, **epoch_override
        )
    elif args.loso:
        # Full leave-one-subject-out evaluation. This call has never forwarded
        # ``epochs``; say so instead of silently dropping the flag.
        # NOTE(review): if leave_one_out_evaluation accepts ``epochs``,
        # forwarding it here would be the better fix — confirm its signature.
        if args.epochs:
            print("[WARN] --epochs 在 --loso 模式下不會生效,已忽略")
        # The return value is not used by this entry point.
        leave_one_out_evaluation(ssl_ckpt, cache_index, stage=args.stage)
    else:
        # Single-subject run.
        finetune(
            ssl_ckpt,
            cache_index,
            test_subject=args.test_subject,
            stage=args.stage,
            **epoch_override,
        )


def run_online(args):
    """Step 4: human-in-the-loop personalization update.

    Per the original notes, this reads the driver's button-correction data and
    updates that user's personalized adapter. It is typically run after a trip:

        python main.py --mode online --user_id driver_01

    Args:
        args: Parsed CLI namespace. Reads ``ssl_ckpt``, ``finetune_ckpt``,
            ``user_id`` and ``keep_corrections``.

    Returns:
        None. If either checkpoint file is missing, an error is printed and
        ``run_online_update`` is not called.
    """
    print("\n[Step 4] Online Learning 個人化更新")

    from training.online_learning import run_online_update

    # Explicit paths win; otherwise use the default files under CKPT_DIR.
    ssl_path = (
        args.ssl_ckpt if args.ssl_ckpt
        else os.path.join(CKPT_DIR, "ssl_best.pth")
    )
    general_path = (
        args.finetune_ckpt if args.finetune_ckpt
        else os.path.join(CKPT_DIR, "classifier_general.pth")
    )

    # Each prerequisite is paired with the lines to print when it is missing.
    # They are checked in order and the first missing one aborts the update.
    prerequisites = (
        (ssl_path, (f"[ERROR] 找不到 SSL checkpoint: {ssl_path}",)),
        (
            general_path,
            (
                f"[ERROR] 找不到通用模型: {general_path}",
                "請先執行: python main.py --mode finetune --general",
            ),
        ),
    )
    for required_path, messages in prerequisites:
        if not os.path.exists(required_path):
            print(*messages, sep="\n")
            return

    run_online_update(
        user_id=args.user_id,
        ssl_ckpt=ssl_path,
        general_ckpt=general_path,
        keep_corrections=args.keep_corrections,
    )


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the pipeline entry point.

    Kept separate from ``main`` so the CLI definition can be inspected and
    tested without running any pipeline step.

    Returns:
        The configured parser. ``--mode`` is required; every other option has
        a default.
    """
    parser = argparse.ArgumentParser(
        description="疲勞駕駛偵測 — SSL + Online Learning"
    )
    parser.add_argument(
        "--mode",
        choices=["preprocess", "ssl", "finetune", "online"],
        required=True,
        help="執行模式",
    )

    # Options shared by several modes.
    parser.add_argument("--epochs", type=int, default=None,
                        help="覆蓋 config 中的 epoch 數")
    # NOTE(review): the help text mentions Online, but run_online does not
    # read ``test_subject`` (it uses ``--user_id``) — confirm and reword.
    parser.add_argument("--test_subject", type=str, default="01",
                        help="Fine-tuning / Online 時的測試受試者 ID")

    # SSL options.
    parser.add_argument("--resume", type=str, default=None,
                        help="從 checkpoint 繼續 SSL 訓練")

    # Fine-tuning options.
    parser.add_argument("--ssl_ckpt", type=str, default=None,
                        help="SSL checkpoint 路徑")
    parser.add_argument("--stage",
                        choices=["linear", "full"], default="linear",
                        help="linear=只訓練頭, full=解凍 encoder")
    parser.add_argument("--loso", action="store_true",
                        help="執行完整 leave-one-subject-out 評估")
    parser.add_argument("--general", action="store_true",
                        help="用全部資料訓練通用部署模型")

    # Online-learning options.
    parser.add_argument("--finetune_ckpt", type=str, default=None,
                        help="Fine-tune checkpoint 路徑(online 時預設用 classifier_general.pth)")
    parser.add_argument("--user_id", type=str, default="driver_01",
                        help="使用者 ID,用於儲存個人化 adapter(預設 driver_01)")
    parser.add_argument("--keep_corrections", action="store_true",
                        help="更新後保留修正資料(預設更新後清空)")

    # Miscellaneous.
    parser.add_argument("--force_rebuild", action="store_true",
                        help="強制重建幀快取(即使已存在)")
    return parser


def main(argv=None):
    """Parse the command line and run the step selected by ``--mode``.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what running the script does.
    """
    args = build_parser().parse_args(argv)

    # Dispatch table: each mode name maps to its step function.
    mode_map = {
        "preprocess": run_preprocess,
        "ssl": run_ssl,
        "finetune": run_finetune,
        "online": run_online,
    }
    mode_map[args.mode](args)


if __name__ == "__main__":
    main()