Skip to content

Commit f1c4637

Browse files
committed
fix benchmark issues
1 parent d152328 commit f1c4637

2 files changed

Lines changed: 9 additions & 6 deletions

File tree

deepvac/backbones/resnet.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def testFly(self):
267267
max_res = torch.max(softmaxs, dim=1)
268268
max_probability, max_index = max_res
269269
LOG.logI("path: {}, max_probability:{}, max_index:{}".format(path[0], max_probability.item(), max_index.item()))
270+
self.config.sample = input_tensor
270271

271272
def auditConfig():
272273
config.core.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -280,7 +281,7 @@ def auditConfig():
280281
config.core.log_every = 100
281282
config.core.num_workers = 4
282283

283-
config.core.net = ResNet50()
284+
config.core.net = resnet50(pretrained=True)
284285

285286
config.core.optimizer = optim.SGD(
286287
config.core.net.parameters(),
@@ -334,7 +335,7 @@ def auditConfig():
334335
if op == 'test':
335336
if(len(sys.argv) != 4):
336337
LOG.logE("Usage: python -m deepvac.backbones.resnet test <pretrained_model.pth> <your_test_img_input_dir>", exit=True)
337-
338+
#config.core.network_audit_disabled=True
338339
config.core.model_path = sys.argv[2]
339340
config.cast.ScriptCast.model_dir = "./script.pt"
340341
config.cast.ScriptCast.static_quantize_dir = "./static_quantize.pt"
@@ -348,11 +349,10 @@ def auditConfig():
348349
if(len(sys.argv) != 4):
349350
LOG.logE("Usage: python -m deepvac.backbones.resnet benchmark <pretrained_model.pth> <your_input_img.jpg>", exit=True)
350351

351-
config.core.model_reinterpret_cast = False
352-
config.core.cast_state_dict_strict = False
352+
# config.core.model_reinterpret_cast = False
353+
# config.core.cast_state_dict_strict = False
353354
# config.core.net_omit_keys = ['num_batches_tracked']
354355
# config.core.net_omit_keys_strict = False
355-
# config.core.network_audit_disabled=False
356356
config.core.model_path = sys.argv[2]
357357
config.core.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
358358
img_path = sys.argv[3]

deepvac/core/deepvac.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,10 @@ def process(self, input_tensor):
261261
LOG.logI("You did not provide input with config.core.sample...")
262262

263263
LOG.logI("testFly() is your last chance, you must have already reimplemented testFly() in subclass {}, right?".format(self.name()))
264-
return self.testFly()
264+
x = self.testFly()
265+
if self.config.sample is None:
266+
LOG.logE("You must set self.config.sample in testFly() reimplementation in your subclass {}".format(self.name()), exit=True)
267+
return x
265268

266269
def __call__(self, input_tensor=None):
267270
self.auditConfig()

0 commit comments

Comments
 (0)