Skip to content

Commit ca52284

Browse files
authored
remove redundant inference
1 parent 8290244 commit ca52284

File tree

1 file changed

+0
-22
lines changed

1 file changed

+0
-22
lines changed

training/detectors/efficientnetb4_detector.py

-22
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,6 @@ def __init__(self, config):
5858
self.config = config
5959
self.backbone = self.build_backbone(config)
6060
self.loss_func = self.build_loss(config)
61-
self.prob, self.label = [], []
62-
self.correct, self.total = 0, 0
6361

6462
def build_backbone(self, config):
6563
# prepare the backbone
@@ -111,26 +109,6 @@ def forward(self, data_dict: dict, inference=False) -> dict:
111109
prob = torch.softmax(pred, dim=1)[:, 1]
112110
# build the prediction dict for each output
113111
pred_dict = {'cls': pred, 'prob': prob, 'feat': features}
114-
if inference:
115-
self.prob.append(
116-
pred_dict['prob']
117-
.detach()
118-
.squeeze()
119-
.cpu()
120-
.numpy()
121-
)
122-
self.label.append(
123-
data_dict['label']
124-
.detach()
125-
.squeeze()
126-
.cpu()
127-
.numpy()
128-
)
129-
# deal with acc
130-
_, prediction_class = torch.max(pred, 1)
131-
correct = (prediction_class == data_dict['label']).sum().item()
132-
self.correct += correct
133-
self.total += data_dict['label'].size(0)
134112

135113
return pred_dict
136114

0 commit comments

Comments
 (0)