57
57
OUTPUT_DIR: .
58
58
"""
59
59
60
-
61
60
'''
62
61
# author: Zhiyuan Yan
63
62
85
84
}
86
85
'''
87
86
88
- import os
89
- import datetime
90
87
import logging
91
- import numpy as np
92
- from sklearn import metrics
93
- from typing import Union
94
- from collections import defaultdict
95
-
96
- import torch
97
- import torch .nn as nn
98
- import torch .nn .functional as F
99
- import torch .optim as optim
100
- from torch .nn import DataParallel
101
- from torch .utils .tensorboard import SummaryWriter
102
-
103
- from metrics .base_metrics_class import calculate_metrics_for_train
88
+ import os
89
+ import sys
104
90
105
- from .base_detector import AbstractDetector
106
91
from detectors import DETECTOR
107
- from networks import BACKBONE
108
92
from loss import LOSSFUNC
93
+ from metrics .base_metrics_class import calculate_metrics_for_train
109
94
95
+ from .base_detector import AbstractDetector
110
96
111
- import os
112
- import sys
113
97
current_file_path = os .path .abspath (__file__ )
114
98
parent_dir = os .path .dirname (os .path .dirname (current_file_path ))
115
99
project_root_dir = os .path .dirname (parent_dir )
120
104
from .utils .slowfast .models .video_model_builder import ResNet as ResNetOri
121
105
from .utils .slowfast .config .defaults import get_cfg
122
106
from torch import nn
123
- import random
124
-
125
107
126
108
random_select = True
127
109
no_time_pool = True
128
110
129
111
logger = logging .getLogger (__name__ )
130
112
113
+
131
114
@DETECTOR .register_module (module_name = 'i3d' )
132
115
class I3DDetector (AbstractDetector ):
133
116
def __init__ (self , config ):
@@ -137,7 +120,7 @@ def __init__(self, config):
137
120
cfg .NUM_GPUS = 1
138
121
cfg .TEST .BATCH_SIZE = 1
139
122
cfg .TRAIN .BATCH_SIZE = 1
140
- cfg .DATA .NUM_FRAMES = 16
123
+ cfg .DATA .NUM_FRAMES = config [ 'clip_size' ]
141
124
self .resnet = ResNetOri (cfg )
142
125
if config ['pretrained' ] is not None :
143
126
print (f"loading pretrained model from { config ['pretrained' ]} " )
@@ -150,45 +133,44 @@ def __init__(self, config):
150
133
self .resnet .load_state_dict (modified_weights , strict = True )
151
134
152
135
self .loss_func = nn .BCELoss () # The output of the model is a probability value between 0 and 1 (haved used sigmoid)
153
-
136
+
154
137
def build_backbone (self , config ):
155
138
pass
156
-
139
+
157
140
def build_loss (self , config ):
158
141
# prepare the loss function
159
142
loss_class = LOSSFUNC [config ['loss_func' ]]
160
143
loss_func = loss_class ()
161
144
return loss_func
162
-
145
+
163
146
def features (self , data_dict : dict ) -> torch .tensor :
164
- inputs = [data_dict ['image' ].permute (0 ,2 , 1 , 3 , 4 )]
147
+ inputs = [data_dict ['image' ].permute (0 , 2 , 1 , 3 , 4 )]
165
148
pred = self .resnet (inputs )
166
- output = {}
167
- output [ "final_output" ] = pred
149
+ output = {"final_output" : pred }
150
+
168
151
return output ["final_output" ]
169
152
170
153
def classifier (self , features : torch .tensor ):
171
154
pass
172
-
155
+
173
156
def get_losses (self , data_dict : dict , pred_dict : dict ) -> dict :
174
157
label = data_dict ['label' ].float ()
175
158
pred = pred_dict ['cls' ].view (- 1 )
176
159
loss = self .loss_func (pred , label )
177
160
loss_dict = {'overall' : loss }
161
+
178
162
return loss_dict
179
-
163
+
180
164
def get_train_metrics (self , data_dict : dict , pred_dict : dict ) -> dict :
181
165
label = data_dict ['label' ]
182
166
pred = pred_dict ['cls' ]
183
- # compute metrics for batch data
184
167
auc , eer , acc , ap = calculate_metrics_for_train (label .detach (), pred .detach ())
185
168
metric_batch_dict = {'acc' : acc , 'auc' : auc , 'eer' : eer , 'ap' : ap }
169
+
186
170
return metric_batch_dict
187
171
188
172
def forward (self , data_dict : dict , inference = False ) -> dict :
189
- # get the probability
190
173
prob = self .features (data_dict )
191
- # build the prediction dict for each output
192
174
pred_dict = {'cls' : prob , 'prob' : prob , 'feat' : prob }
193
-
175
+
194
176
return pred_dict
0 commit comments