
Commit ad2d546

Update I3D detector
Clean code format and update config file for training
1 parent: fa251e1

2 files changed (+25, -56 lines)


training/config/detector/i3d.yaml (+8, -21)
@@ -1,37 +1,25 @@
-# log dir
-log_dir: /data/home/zhiyuanyan/logs/i3d
-
 # model setting
-pretrained: /data/home/zhiyuanyan/torch_ckpts/I3D_8x8_R50.pth # path to a pre-trained model, if using one
+pretrained: training/pretrained/I3D_8x8_R50.pth # path to a pre-trained model, if using one
 model_name: i3d # model name
-backbone_name: xception # backbone name
-
-#backbone setting
-backbone_config:
-  mode: original
-  num_classes: 1
-  inc: 3
-  dropout: false

 # dataset
 all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV]
 train_dataset: [FaceForensics++]
 test_dataset: [Celeb-DF-v2]

 compression: c23 # compression-level for videos
-train_batchSize: 8 # training batch size
-test_batchSize: 8 # test batch size
+train_batchSize: 32 # training batch size
+test_batchSize: 32 # test batch size
 workers: 8 # number of data loading workers
 frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing
 resolution: 224 # resolution of output image to network
 with_mask: false # whether to include mask information in the input
 with_landmark: false # whether to include facial landmark information in the input
 video_mode: True # whether to use video-level data
-clip_size: 16 # number of frames in each clip
-
+clip_size: 8 # number of frames in each clip

 # data augmentation
-use_data_augmentation: true # Add this flag to enable/disable data augmentation
+use_data_augmentation: false # Add this flag to enable/disable data augmentation
 data_aug:
   flip_prob: 0.5
   rotate_prob: 0.5
@@ -45,8 +33,8 @@ data_aug:
   quality_upper: 100

 # mean and std for normalization
-mean: [0.5, 0.5, 0.5]
-std: [0.5, 0.5, 0.5]
+mean: [0.485, 0.456, 0.406]
+std: [0.229, 0.224, 0.225]

 # optimizer config
 optimizer:
@@ -66,7 +54,7 @@ optimizer:

 # training config
 lr_scheduler: null # learning rate scheduler
-nEpochs: 30 # number of epochs to train for
+nEpochs: 100 # number of epochs to train for
 start_epoch: 0 # manual epoch number (useful for restarts)
 save_epoch: 1 # interval epochs for saving models
 rec_iter: 100 # interval iterations for recording
@@ -83,6 +71,5 @@ losstype: null
 metric_scoring: auc # metric for evaluation (auc, acc, eer, ap)

 # cuda
-
 cuda: true # whether to use CUDA acceleration
 cudnn: true # whether to use CuDNN for convolution operations
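
For context, the updated values can be sanity-checked by loading the YAML directly. The snippet below is an illustrative sketch, not part of the commit; it assumes PyYAML and torchvision are installed and that it is run from the repository root, using the config path shown in this diff.

import yaml
from torchvision import transforms

# Load the updated config (path taken from this commit).
with open('training/config/detector/i3d.yaml', 'r') as f:
    config = yaml.safe_load(f)

# The commit switches normalization from [0.5, 0.5, 0.5] to ImageNet statistics.
normalize = transforms.Normalize(mean=config['mean'], std=config['std'])

# clip_size (now 8) also drives cfg.DATA.NUM_FRAMES in i3d_detector.py below.
print(config['clip_size'], config['train_batchSize'], config['nEpochs'])  # 8 32 100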

training/detectors/i3d_detector.py (+17, -35)
@@ -57,7 +57,6 @@
 OUTPUT_DIR: .
 """

-
 '''
 # author: Zhiyuan Yan
@@ -85,31 +84,16 @@
 }
 '''

-import os
-import datetime
 import logging
-import numpy as np
-from sklearn import metrics
-from typing import Union
-from collections import defaultdict
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-from torch.nn import DataParallel
-from torch.utils.tensorboard import SummaryWriter
-
-from metrics.base_metrics_class import calculate_metrics_for_train
+import os
+import sys

-from .base_detector import AbstractDetector
 from detectors import DETECTOR
-from networks import BACKBONE
 from loss import LOSSFUNC
+from metrics.base_metrics_class import calculate_metrics_for_train

+from .base_detector import AbstractDetector

-import os
-import sys
 current_file_path = os.path.abspath(__file__)
 parent_dir = os.path.dirname(os.path.dirname(current_file_path))
 project_root_dir = os.path.dirname(parent_dir)
@@ -120,14 +104,13 @@
 from .utils.slowfast.models.video_model_builder import ResNet as ResNetOri
 from .utils.slowfast.config.defaults import get_cfg
 from torch import nn
-import random
-

 random_select = True
 no_time_pool = True

 logger = logging.getLogger(__name__)

+
 @DETECTOR.register_module(module_name='i3d')
 class I3DDetector(AbstractDetector):
     def __init__(self, config):
@@ -137,7 +120,7 @@ def __init__(self, config):
         cfg.NUM_GPUS = 1
         cfg.TEST.BATCH_SIZE = 1
         cfg.TRAIN.BATCH_SIZE = 1
-        cfg.DATA.NUM_FRAMES = 16
+        cfg.DATA.NUM_FRAMES = config['clip_size']
         self.resnet = ResNetOri(cfg)
         if config['pretrained'] is not None:
             print(f"loading pretrained model from {config['pretrained']}")
@@ -150,45 +133,44 @@ def __init__(self, config):
             self.resnet.load_state_dict(modified_weights, strict=True)

         self.loss_func = nn.BCELoss() # The output of the model is a probability value between 0 and 1 (haved used sigmoid)
-
+
     def build_backbone(self, config):
         pass
-
+
     def build_loss(self, config):
         # prepare the loss function
         loss_class = LOSSFUNC[config['loss_func']]
         loss_func = loss_class()
         return loss_func
-
+
     def features(self, data_dict: dict) -> torch.tensor:
-        inputs = [data_dict['image'].permute(0,2,1,3,4)]
+        inputs = [data_dict['image'].permute(0, 2, 1, 3, 4)]
         pred = self.resnet(inputs)
-        output = {}
-        output["final_output"] = pred
+        output = {"final_output": pred}
+
         return output["final_output"]

     def classifier(self, features: torch.tensor):
         pass
-
+
     def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
         label = data_dict['label'].float()
         pred = pred_dict['cls'].view(-1)
         loss = self.loss_func(pred, label)
         loss_dict = {'overall': loss}
+
         return loss_dict
-
+
     def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
         label = data_dict['label']
         pred = pred_dict['cls']
-        # compute metrics for batch data
         auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
         metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}
+
         return metric_batch_dict

     def forward(self, data_dict: dict, inference=False) -> dict:
-        # get the probability
         prob = self.features(data_dict)
-        # build the prediction dict for each output
         pred_dict = {'cls': prob, 'prob': prob, 'feat': prob}
-
+
         return pred_dict
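
The main functional change here is that the backbone's temporal length now follows the config (cfg.DATA.NUM_FRAMES = config['clip_size']) instead of a hard-coded 16. The snippet below is a standalone shape check, not repository code: it assumes clips arrive as (batch, frames, channels, height, width), which is what the permute(0, 2, 1, 3, 4) in features() implies, and reuses the clip_size and resolution values from the updated YAML.

import torch

clip_size = 8      # clip_size from the updated i3d.yaml, now fed into cfg.DATA.NUM_FRAMES
resolution = 224   # resolution from the same config

# Dummy batch of 2 clips laid out as (batch, frames, channels, height, width).
image = torch.randn(2, clip_size, 3, resolution, resolution)

# I3DDetector.features() wraps the permuted tensor in a list before calling the
# SlowFast-style ResNet, which expects (batch, channels, frames, height, width).
inputs = [image.permute(0, 2, 1, 3, 4)]
print(inputs[0].shape)  # torch.Size([2, 3, 8, 224, 224])

forward() then exposes the sigmoid output under the 'cls', 'prob', and 'feat' keys, which is exactly what get_losses() and get_train_metrics() read back in this file.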
