Skip to content

Commit fa251e1

Browse files
authored
remove redundant inference
1 parent 0222760 commit fa251e1

File tree

1 file changed

+0
-23
lines changed

1 file changed

+0
-23
lines changed

training/detectors/altfreezing_detector.py

-23
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,6 @@ def __init__(self, config):
151151
self.resnet.load_state_dict(modified_weights, strict=True)
152152

153153
self.loss_func = nn.BCELoss() # The output of the model is a probability value between 0 and 1 (haved used sigmoid)
154-
self.prob, self.label = [], []
155-
self.correct, self.total = 0, 0
156154

157155
def build_backbone(self, config):
158156
pass
@@ -193,26 +191,5 @@ def forward(self, data_dict: dict, inference=False) -> dict:
193191
prob = self.features(data_dict)
194192
# build the prediction dict for each output
195193
pred_dict = {'cls': prob, 'prob': prob, 'feat': None}
196-
if inference:
197-
self.prob.extend(
198-
pred_dict['prob']
199-
.detach()
200-
.squeeze()
201-
.cpu()
202-
.numpy()
203-
)
204-
self.label.extend(
205-
data_dict['label']
206-
.detach()
207-
.squeeze()
208-
.cpu()
209-
.numpy()
210-
)
211-
# deal with acc
212-
prediction_class = (prob >= 0.5).type(torch.int).view(-1)
213-
assert prediction_class.shape == data_dict['label'].shape
214-
correct = (prediction_class == data_dict['label']).sum().item()
215-
self.correct += correct
216-
self.total += data_dict['label'].size(0)
217194

218195
return pred_dict

0 commit comments

Comments
 (0)